drbh committed on
Commit dcfa38d · 1 Parent(s): 8176cbe

feat: bump build for torch compile

Files changed (18)
  1. build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} +1 -1
  2. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  3. build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +258 -34
  4. build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} +1 -1
  5. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  6. build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +258 -34
  7. build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} +1 -1
  8. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
  9. build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +258 -34
  10. build/torch28-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} +1 -1
  11. build/torch28-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  12. build/torch28-cxx11-cu126-x86_64-linux/megablocks/layers.py +258 -34
  13. build/torch28-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} +1 -1
  14. build/torch28-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
  15. build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py +258 -34
  16. build/torch28-cxx11-cu129-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} +1 -1
  17. build/torch28-cxx11-cu129-x86_64-linux/megablocks/_ops.py +3 -3
  18. build/torch28-cxx11-cu129-x86_64-linux/megablocks/layers.py +258 -34
build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:3f00f02cb159ccecc961af4ceab76fbebd06b61569f8b109a1c63cbcf9cf4a02
+ oid sha256:5ee50c722d5ff355fd4e91d557dffe3be9b674dd5901748dc19286aef37d6d60
  size 10513752
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _megablocks_3bdb4b8_dirty
- ops = torch.ops._megablocks_3bdb4b8_dirty
+ from . import _megablocks_8176cbe_dirty
+ ops = torch.ops._megablocks_8176cbe_dirty
 
  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_megablocks_3bdb4b8_dirty::{op_name}"
+     return f"_megablocks_8176cbe_dirty::{op_name}"
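The `_ops.py` shim only re-points the Python-visible namespace at the freshly built extension. A quick illustration of how the prefix helper is used (the op name `sort` is just an example, and the import path assumes the package is importable as `megablocks`):

```python
from megablocks._ops import ops, add_op_namespace_prefix

# Fully qualified schema name, e.g. for registering meta/fake kernels:
add_op_namespace_prefix("sort")  # -> "_megablocks_8176cbe_dirty::sort"

# The bound namespace mirrors the extension module name:
# ops.sort(...) dispatches into _megablocks_8176cbe_dirty.abi3.so
```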
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -1,11 +1,200 @@
  import torch
  import torch.distributed as dist
 
- from typing import Optional, Any
+ from typing import Optional, Any, TYPE_CHECKING
 
  from . import _layers
  from . import ops
 
+ # Conditional import for meta kernel registration
+ if TYPE_CHECKING:
+
+     def register_fake(fn):
+         return lambda name: fn
+
+ else:
+     try:
+         from torch.library import register_fake
+     except ImportError:
+         try:
+             from torch.library import impl_abstract as register_fake
+         except ImportError:
+             # Fallback for older PyTorch versions
+             def register_fake(op_name):
+                 def decorator(fn):
+                     return fn
+
+                 return decorator
+
+
+ # Meta kernel implementations for torch.compile compatibility
+ def _install_meta_kernels():
+     """Install meta kernels for existing MegaBlocks operations"""
+
+     # Create wrapper functions that check for compilation and return meta tensors
+
+     # Patch ops.sort
+     if hasattr(ops, "sort"):
+         original_sort = ops.sort
+
+         def sort_with_meta(x, end_bit=None):
+             if torch.compiler.is_compiling():
+                 print("Using meta kernel for sort")
+                 # Meta implementation - return tensors with correct shape/dtype/device
+                 return torch.empty_like(x), torch.empty_like(x)
+             # print("Using original sort kernel")
+             return original_sort(x, end_bit)
+
+         ops.sort = sort_with_meta
+
+     # Patch ops.histogram
+     if hasattr(ops, "histogram"):
+         original_histogram = ops.histogram
+
+         def histogram_with_meta(x, max_val):
+             if torch.compiler.is_compiling():
+                 # Meta implementation
+                 return torch.empty((max_val,), dtype=torch.int32, device=x.device)
+             return original_histogram(x, max_val)
+
+         ops.histogram = histogram_with_meta
+
+     # Patch ops.inclusive_cumsum
+     if hasattr(ops, "inclusive_cumsum"):
+         original_inclusive_cumsum = ops.inclusive_cumsum
+
+         def inclusive_cumsum_with_meta(x, dim):
+             if torch.compiler.is_compiling():
+                 # Meta implementation
+                 return torch.empty_like(x)
+             return original_inclusive_cumsum(x, dim)
+
+         ops.inclusive_cumsum = inclusive_cumsum_with_meta
+
+     # Patch ops.binned_gather
+     if hasattr(ops, "binned_gather"):
+         original_binned_gather = ops.binned_gather
+
+         def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
+             if torch.compiler.is_compiling():
+                 # Meta implementation - output shape based on bin_size
+                 if x.dim() >= 2:
+                     hidden_size = x.size(-1)
+                     return torch.empty(
+                         (bin_size, x.size(1), hidden_size),
+                         dtype=x.dtype,
+                         device=x.device,
+                     )
+                 else:
+                     return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
+             return original_binned_gather(x, indices, bins, bin_size, top_k)
+
+         ops.binned_gather = binned_gather_with_meta
+
+     # Patch ops.binned_scatter
+     if hasattr(ops, "binned_scatter"):
+         original_binned_scatter = ops.binned_scatter
+
+         def binned_scatter_with_meta(x, indices, weights, bins, top_k):
+             if torch.compiler.is_compiling():
+                 # Meta implementation - typically reduces to 2D
+                 if x.dim() >= 3:
+                     return torch.empty(
+                         (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
+                     )
+                 else:
+                     return torch.empty_like(x)
+             return original_binned_scatter(x, indices, weights, bins, top_k)
+
+         ops.binned_scatter = binned_scatter_with_meta
+
+     # Patch ops.gather
+     if hasattr(ops, "gather"):
+         original_gather = ops.gather
+
+         def gather_with_meta(x, indices, bin_ids, bins, top_k):
+             if torch.compiler.is_compiling():
+                 # Meta implementation
+                 if x.dim() >= 2:
+                     hidden_size = x.size(-1)
+                     return torch.empty(
+                         (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
+                     )
+                 else:
+                     return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
+             return original_gather(x, indices, bin_ids, bins, top_k)
+
+         ops.gather = gather_with_meta
+
+     # Patch ops.scatter
+     if hasattr(ops, "scatter"):
+         original_scatter = ops.scatter
+
+         def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
+             if torch.compiler.is_compiling():
+                 # Meta implementation - restore sequence shape
+                 seq_len = (
+                     indices.size(0) // top_k
+                     if indices.numel() > 0 and top_k > 0
+                     else x.size(0)
+                 )
+                 if x.dim() >= 2:
+                     return torch.empty(
+                         (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
+                     )
+                 else:
+                     return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
+             return original_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+         ops.scatter = scatter_with_meta
+
+     # Patch ops.replicate
+     if hasattr(ops, "replicate"):
+         original_replicate = ops.replicate
+
+         def replicate_with_meta(x, bins, num_outputs):
+             if torch.compiler.is_compiling():
+                 # Meta implementation
+                 return torch.empty(
+                     (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
+                 )
+             return original_replicate(x, bins, num_outputs)
+
+         ops.replicate = replicate_with_meta
+
+     # Patch ops.repeat (if it's a regular function)
+     if hasattr(ops, "repeat"):
+         original_repeat = ops.repeat
+
+         def repeat_with_meta(x, repeats):
+             if torch.compiler.is_compiling():
+                 # Meta implementation
+                 if isinstance(repeats, (tuple, list)):
+                     new_shape = list(x.shape)
+                     for i, rep in enumerate(repeats):
+                         if i < len(new_shape):
+                             new_shape[i] *= rep
+                     return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+                 else:
+                     new_shape = [x.size(0) * repeats] + list(x.shape[1:])
+                     return torch.empty(new_shape, dtype=x.dtype, device=x.device)
+             return original_repeat(x, repeats)
+
+         ops.repeat = repeat_with_meta
+
+
+ # Install meta kernels on import
+ try:
+     _install_meta_kernels()
+ except Exception as e:
+     # If meta kernel installation fails, continue without them
+     # torch.compile may not work but the library will still function
+     import warnings
+
+     warnings.warn(
+         f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
+     )
+
 
  # Set the expert model parallel attributes on a tensor
  def set_expert_model_parallel_attributes(
@@ -80,6 +269,7 @@ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
  def route_tokens(
      x: torch.Tensor,
      router_weight: torch.Tensor,
+     router_bias: torch.Tensor,
      moe_top_k: int,
      moe_num_experts: int,
      moe_jitter_eps: float = None,
@@ -91,7 +281,7 @@ def route_tokens(
      x = apply_jitter(x, moe_jitter_eps)
 
      x_flat = x.view(-1, x.shape[-1])
-     logits = torch.nn.functional.linear(x_flat, router_weight)
+     logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
      expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
      expert_weights = expert_weights.softmax(dim=-1)
      if moe_normalize_expert_weights is not None:
@@ -129,6 +319,7 @@ def mlp_forward(
      w2_bias: torch.Tensor,
      gradient_scale: Optional[float] = None,
      alpha: float = 1.702,
+     limit: float = 7.0,
  ):
      # Scale weights
      w1 = scale_grad(w1, gradient_scale)
@@ -144,13 +335,13 @@ def mlp_forward(
 
      # Forward pass
      gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
-     gate, up = gate_up.chunk(2, dim=-1)
-
+     gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+     gate = gate.clamp(min=None, max=limit)
+     up = up.clamp(min=-limit, max=limit)
      glu = gate * torch.sigmoid(gate * alpha)
-     x = (up + 1) * glu
-
-     return torch.bmm(x, w2) + w2_bias[..., None, :]
-
+     next_states = torch.bmm(((up + 1) * glu), w2)
+     next_states += w2_bias[..., None, :]
+     return next_states
 
  # Shared expert MLP forward pass
  def shared_mlp_forward(
@@ -184,13 +375,13 @@ def shared_mlp_forward(
 
      # Up projection
      x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
-
+
      # Activation
      x = activation_fn(x)
-
+
      # Down projection
      x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
-
+
      return x
 
 
@@ -657,6 +848,7 @@ def parallel_forward_once(
  def moe_forward(
      x: torch.Tensor,
      router_weight: torch.Tensor,
+     router_bias: Optional[torch.Tensor],
      moe_top_k: int,
      moe_num_experts: int,
      moe_jitter_eps: float = None,
@@ -682,6 +874,7 @@ def moe_forward(
      logits, expert_weights, expert_indices = route_tokens(
          x,
          router_weight,
+         router_bias,
          moe_top_k,
          moe_num_experts,
          moe_jitter_eps,
@@ -743,6 +936,7 @@ def moe_forward(
  def moe_forward_with_shared_expert(
      x: torch.Tensor,
      router_weight: torch.Tensor,
+     router_bias: Optional[torch.Tensor],
      moe_top_k: int,
      moe_num_experts: int,
      moe_jitter_eps: float = None,
@@ -775,6 +969,7 @@ def moe_forward_with_shared_expert(
      expert_out, expert_weights, router_scores = moe_forward(
          x=x,
          router_weight=router_weight,
+         router_bias=router_bias,
          moe_top_k=moe_top_k,
          moe_num_experts=moe_num_experts,
          moe_jitter_eps=moe_jitter_eps,
@@ -795,7 +990,7 @@ def moe_forward_with_shared_expert(
          hidden_size=hidden_size,
          mlp_impl=mlp_impl,
      )
-
+
      # If shared expert weights provided, compute shared expert output
      if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
          shared_expert_out = shared_mlp_forward(
@@ -807,7 +1002,7 @@ def moe_forward_with_shared_expert(
              activation_fn=shared_activation_fn,
              gradient_scale=gradient_scale,
          )
-
+
          # Combine expert outputs
          combined_out = combine_expert_shared_outputs(
              shared_expert_out=shared_expert_out,
@@ -815,9 +1010,9 @@ def moe_forward_with_shared_expert(
              shared_expert_weighted_sum=shared_expert_weighted_sum,
              moe_top_k=moe_top_k,
          )
-
+
          return combined_out, expert_weights, router_scores
-
+
      # Return regular MoE output if no shared expert
      return expert_out, expert_weights, router_scores
 
@@ -833,7 +1028,7 @@ def create_shared_expert_weights(
 
      if output_layer_init_method is None:
          output_layer_init_method = init_method
-
+
      # Create weight tensors
      up_proj_weight = torch.empty(
          shared_expert_hidden_size,
@@ -847,14 +1042,15 @@ def create_shared_expert_weights(
          device=device,
          dtype=dtype,
      )
-
+
      # Initialize weights
      init_method(up_proj_weight)
      output_layer_init_method(down_proj_weight)
-
+
      # No bias by default
      return up_proj_weight, down_proj_weight, None, None
 
+
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
  # This exists because device_mesh is trapped in hook closures with no model attribute
  # Fragile - breaks if hook structure changes or Python internals change
@@ -863,14 +1059,21 @@ def get_device_mesh(model):
      # Extract device_mesh from child's unused pre_hook closure
      try:
          # Find the pre-hook that contains 'device_mesh' in its closure
-         hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+         hook = next(
+             h
+             for h in model.experts._forward_pre_hooks.values()
+             if "device_mesh" in h.__code__.co_freevars
+         )
          # Extract the device_mesh from the closure
-         return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+         return hook.__closure__[
+             hook.__code__.co_freevars.index("device_mesh")
+         ].cell_contents
      except Exception:
          return None
 
 
  class MegaBlocksMoeMLP(torch.nn.Module):
+     can_torch_compile: bool = True
 
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          moe_top_k = getattr(self.router, "top_k", 4)
@@ -879,7 +1082,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
          alpha = getattr(self.experts, "alpha", 1.0)
          moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
          moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
-         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+         moe_normalize_expert_weights = getattr(
+             self.experts, "normalize_expert_weights", None
+         )
          uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
          expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -887,15 +1092,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
              device_mesh = get_device_mesh(self)
              expert_parallel_group = device_mesh.get_group() if device_mesh else None
 
-         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+         has_parallel = (
+             expert_parallel_group is not None
+             and dist.is_initialized()
+             and dist.get_world_size(expert_parallel_group) > 1
+         )
          forward_fn = parallel_forward_once if has_parallel else forward_once
-
-         sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+
+         sort_end_bit = max(
+             int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+         )
          mlp_impl = getattr(self, "mlp_impl", "grouped")
-
          output, expert_weights_out, *_ = moe_forward(
              x=x,
              router_weight=self.router.weight,
+             router_bias=self.router.bias,
              moe_top_k=moe_top_k,
              moe_num_experts=moe_num_experts,
              moe_jitter_eps=moe_jitter_eps,
@@ -919,8 +1130,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
          return output, expert_weights_out
 
 
+ # Export main classes
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
+
+
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
-
+
      def __init__(self):
          super().__init__()
          # Shared expert weights will be set by the user
@@ -930,7 +1145,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
          self.shared_down_proj_bias = None
          self.shared_expert_weighted_sum = False
          self.shared_activation_fn = None
-
+
      def set_shared_expert_weights(
          self,
          up_proj_weight: torch.Tensor,
@@ -946,7 +1161,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
          self.shared_down_proj_bias = down_proj_bias
          self.shared_expert_weighted_sum = weighted_sum
          self.shared_activation_fn = activation_fn
-
+
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          moe_top_k = getattr(self.router, "top_k", 4)
          moe_num_experts = getattr(self.experts, "num_experts", 128)
@@ -954,7 +1169,9 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
          alpha = getattr(self.experts, "alpha", 1.0)
          moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
          moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
-         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+         moe_normalize_expert_weights = getattr(
+             self.experts, "normalize_expert_weights", None
+         )
          uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
          expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -962,15 +1179,22 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
              device_mesh = get_device_mesh(self)
              expert_parallel_group = device_mesh.get_group() if device_mesh else None
 
-         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+         has_parallel = (
+             expert_parallel_group is not None
+             and dist.is_initialized()
+             and dist.get_world_size(expert_parallel_group) > 1
+         )
          forward_fn = parallel_forward_once if has_parallel else forward_once
-
-         sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+
+         sort_end_bit = max(
+             int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+         )
          mlp_impl = getattr(self, "mlp_impl", "grouped")
-
+
          output, expert_weights_out, *_ = moe_forward_with_shared_expert(
              x=x,
              router_weight=self.router.weight,
+             router_bias=self.router.bias,
              moe_top_k=moe_top_k,
              moe_num_experts=moe_num_experts,
              moe_jitter_eps=moe_jitter_eps,
@@ -998,4 +1222,4 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
              shared_expert_weighted_sum=self.shared_expert_weighted_sum,
              shared_activation_fn=self.shared_activation_fn,
          )
-         return output, expert_weights_out
+         return output, expert_weights_out
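The layers.py change above is what makes the build usable under torch.compile: each fused op exposed through `ops` is wrapped so that a shape-and-dtype-only "meta" result is returned while the compiler is tracing, while the real CUDA kernel still runs in eager mode, and `can_torch_compile = True` advertises this to integrations. A minimal sketch of the same wrapping pattern in isolation (the histogram-like op and its meta shape here are illustrative stand-ins, not the library's API):

```python
import torch


def _wrap_with_meta(original_op, meta_fn):
    """Route to a shape/dtype-only implementation while torch.compile is tracing."""

    def wrapper(*args, **kwargs):
        if torch.compiler.is_compiling():      # True during Dynamo tracing
            return meta_fn(*args, **kwargs)    # tensor with the right metadata only
        return original_op(*args, **kwargs)    # real kernel in eager execution

    return wrapper


# Illustrative eager op: count tokens per expert (stand-in for ops.histogram).
def histogram_eager(x: torch.Tensor, max_val: int) -> torch.Tensor:
    return torch.bincount(x, minlength=max_val)[:max_val].to(torch.int32)


histogram = _wrap_with_meta(
    histogram_eager,
    lambda x, max_val: torch.empty((max_val,), dtype=torch.int32, device=x.device),
)
```

In the diff this patching is applied to `sort`, `histogram`, `inclusive_cumsum`, the gather/scatter family, `replicate`, and `repeat` at import time, and the installation is wrapped in a try/except so a failure only costs torch.compile support rather than breaking eager use.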
build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:f286efaecea4ae5f49c6c661285e5a8b40808908b5382f810b0941295ef0ae4d
+ oid sha256:eff5def5d00d090fb74d18f9d5101d8f41d71adc70fd478407380f0813a8ba44
  size 11927016
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
(Same three-line hash bump as the torch27-cxx11-cu118 copy of _ops.py above: the bindings now import and namespace _megablocks_8176cbe_dirty instead of _megablocks_3bdb4b8_dirty.)
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
(Identical diff to the torch27-cxx11-cu118 copy of layers.py above: +258 -34, adding the TYPE_CHECKING-guarded register_fake import, the _install_meta_kernels() patching for torch.compile, router_bias plumbing through route_tokens/moe_forward, the clamped interleaved gate/up activation in mlp_forward, and can_torch_compile = True on MegaBlocksMoeMLP.)
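For reference, the one functional change repeated in every copy of layers.py is the new `mlp_forward` activation: the fused gate/up projection is split by interleaved slicing rather than `chunk`, both halves are clamped by `limit`, and the result feeds a sigmoid-gated linear unit. A standalone sketch of just that math (tensor names and the per-expert shapes in the comments are illustrative):

```python
import torch


def gated_mlp(x, w1, w1_bias, w2, w2_bias, alpha=1.702, limit=7.0):
    # x: (num_experts, tokens, hidden); w1: (num_experts, hidden, 2*ffn); w2: (num_experts, ffn, hidden)
    gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # interleaved split, not chunk
    gate = gate.clamp(max=limit)                      # cap gate pre-activations
    up = up.clamp(min=-limit, max=limit)              # cap up-projection symmetrically
    glu = gate * torch.sigmoid(gate * alpha)          # sigmoid-gated linear unit
    return torch.bmm((up + 1) * glu, w2) + w2_bias[..., None, :]
```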
build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so → _megablocks_8176cbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:cbc2acbd9421ba25cb0972da6915168f0ff88ce1e5fce547bdd240319945b212
+ oid sha256:2f22a56b5e69d365a2f077c9713f4eae7d325e201fee3cf705af6663d9cb854a
  size 17884448
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py CHANGED
(Same three-line hash bump as the torch27-cxx11-cu118 copy of _ops.py above.)
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py CHANGED
(Identical diff to the torch27-cxx11-cu118 copy of layers.py above.)
930
  self.shared_down_proj_bias = None
931
  self.shared_expert_weighted_sum = False
932
  self.shared_activation_fn = None
933
-
934
  def set_shared_expert_weights(
935
  self,
936
  up_proj_weight: torch.Tensor,
@@ -946,7 +1161,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
946
  self.shared_down_proj_bias = down_proj_bias
947
  self.shared_expert_weighted_sum = weighted_sum
948
  self.shared_activation_fn = activation_fn
949
-
950
  def forward(self, x: torch.Tensor) -> torch.Tensor:
951
  moe_top_k = getattr(self.router, "top_k", 4)
952
  moe_num_experts = getattr(self.experts, "num_experts", 128)
@@ -954,7 +1169,9 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
954
  alpha = getattr(self.experts, "alpha", 1.0)
955
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
958
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
 
960
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -962,15 +1179,22 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
962
  device_mesh = get_device_mesh(self)
963
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
 
965
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
966
  forward_fn = parallel_forward_once if has_parallel else forward_once
967
-
968
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
969
  mlp_impl = getattr(self, "mlp_impl", "grouped")
970
-
971
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
  x=x,
973
  router_weight=self.router.weight,
 
974
  moe_top_k=moe_top_k,
975
  moe_num_experts=moe_num_experts,
976
  moe_jitter_eps=moe_jitter_eps,
@@ -998,4 +1222,4 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
998
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
  shared_activation_fn=self.shared_activation_fn,
1000
  )
1001
- return output, expert_weights_out
 
1
  import torch
2
  import torch.distributed as dist
3
 
4
+ from typing import Optional, Any, TYPE_CHECKING
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
+ # Conditional import for meta kernel registration
10
+ if TYPE_CHECKING:
11
+
12
+ def register_fake(fn):
13
+ return lambda name: fn
14
+
15
+ else:
16
+ try:
17
+ from torch.library import register_fake
18
+ except ImportError:
19
+ try:
20
+ from torch.library import impl_abstract as register_fake
21
+ except ImportError:
22
+ # Fallback for older PyTorch versions
23
+ def register_fake(op_name):
24
+ def decorator(fn):
25
+ return fn
26
+
27
+ return decorator
28
+
29
+
30
+ # Meta kernel implementations for torch.compile compatibility
31
+ def _install_meta_kernels():
32
+ """Install meta kernels for existing MegaBlocks operations"""
33
+
34
+ # Create wrapper functions that check for compilation and return meta tensors
35
+
36
+ # Patch ops.sort
37
+ if hasattr(ops, "sort"):
38
+ original_sort = ops.sort
39
+
40
+ def sort_with_meta(x, end_bit=None):
41
+ if torch.compiler.is_compiling():
42
+ print("Using meta kernel for sort")
43
+ # Meta implementation - return tensors with correct shape/dtype/device
44
+ return torch.empty_like(x), torch.empty_like(x)
45
+ # print("Using original sort kernel")
46
+ return original_sort(x, end_bit)
47
+
48
+ ops.sort = sort_with_meta
49
+
50
+ # Patch ops.histogram
51
+ if hasattr(ops, "histogram"):
52
+ original_histogram = ops.histogram
53
+
54
+ def histogram_with_meta(x, max_val):
55
+ if torch.compiler.is_compiling():
56
+ # Meta implementation
57
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
58
+ return original_histogram(x, max_val)
59
+
60
+ ops.histogram = histogram_with_meta
61
+
62
+ # Patch ops.inclusive_cumsum
63
+ if hasattr(ops, "inclusive_cumsum"):
64
+ original_inclusive_cumsum = ops.inclusive_cumsum
65
+
66
+ def inclusive_cumsum_with_meta(x, dim):
67
+ if torch.compiler.is_compiling():
68
+ # Meta implementation
69
+ return torch.empty_like(x)
70
+ return original_inclusive_cumsum(x, dim)
71
+
72
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
73
+
74
+ # Patch ops.binned_gather
75
+ if hasattr(ops, "binned_gather"):
76
+ original_binned_gather = ops.binned_gather
77
+
78
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
79
+ if torch.compiler.is_compiling():
80
+ # Meta implementation - output shape based on bin_size
81
+ if x.dim() >= 2:
82
+ hidden_size = x.size(-1)
83
+ return torch.empty(
84
+ (bin_size, x.size(1), hidden_size),
85
+ dtype=x.dtype,
86
+ device=x.device,
87
+ )
88
+ else:
89
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
90
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
91
+
92
+ ops.binned_gather = binned_gather_with_meta
93
+
94
+ # Patch ops.binned_scatter
95
+ if hasattr(ops, "binned_scatter"):
96
+ original_binned_scatter = ops.binned_scatter
97
+
98
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
99
+ if torch.compiler.is_compiling():
100
+ # Meta implementation - typically reduces to 2D
101
+ if x.dim() >= 3:
102
+ return torch.empty(
103
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
104
+ )
105
+ else:
106
+ return torch.empty_like(x)
107
+ return original_binned_scatter(x, indices, weights, bins, top_k)
108
+
109
+ ops.binned_scatter = binned_scatter_with_meta
110
+
111
+ # Patch ops.gather
112
+ if hasattr(ops, "gather"):
113
+ original_gather = ops.gather
114
+
115
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
116
+ if torch.compiler.is_compiling():
117
+ # Meta implementation
118
+ if x.dim() >= 2:
119
+ hidden_size = x.size(-1)
120
+ return torch.empty(
121
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
122
+ )
123
+ else:
124
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
125
+ return original_gather(x, indices, bin_ids, bins, top_k)
126
+
127
+ ops.gather = gather_with_meta
128
+
129
+ # Patch ops.scatter
130
+ if hasattr(ops, "scatter"):
131
+ original_scatter = ops.scatter
132
+
133
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
134
+ if torch.compiler.is_compiling():
135
+ # Meta implementation - restore sequence shape
136
+ seq_len = (
137
+ indices.size(0) // top_k
138
+ if indices.numel() > 0 and top_k > 0
139
+ else x.size(0)
140
+ )
141
+ if x.dim() >= 2:
142
+ return torch.empty(
143
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
144
+ )
145
+ else:
146
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
147
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
148
+
149
+ ops.scatter = scatter_with_meta
150
+
151
+ # Patch ops.replicate
152
+ if hasattr(ops, "replicate"):
153
+ original_replicate = ops.replicate
154
+
155
+ def replicate_with_meta(x, bins, num_outputs):
156
+ if torch.compiler.is_compiling():
157
+ # Meta implementation
158
+ return torch.empty(
159
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
160
+ )
161
+ return original_replicate(x, bins, num_outputs)
162
+
163
+ ops.replicate = replicate_with_meta
164
+
165
+ # Patch ops.repeat (if it's a regular function)
166
+ if hasattr(ops, "repeat"):
167
+ original_repeat = ops.repeat
168
+
169
+ def repeat_with_meta(x, repeats):
170
+ if torch.compiler.is_compiling():
171
+ # Meta implementation
172
+ if isinstance(repeats, (tuple, list)):
173
+ new_shape = list(x.shape)
174
+ for i, rep in enumerate(repeats):
175
+ if i < len(new_shape):
176
+ new_shape[i] *= rep
177
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
178
+ else:
179
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
180
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
181
+ return original_repeat(x, repeats)
182
+
183
+ ops.repeat = repeat_with_meta
184
+
185
+
186
+ # Install meta kernels on import
187
+ try:
188
+ _install_meta_kernels()
189
+ except Exception as e:
190
+ # If meta kernel installation fails, continue without them
191
+ # torch.compile may not work but the library will still function
192
+ import warnings
193
+
194
+ warnings.warn(
195
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
196
+ )
197
+
198
 
199
  # Set the expert model parallel attributes on a tensor
200
  def set_expert_model_parallel_attributes(
 
269
  def route_tokens(
270
  x: torch.Tensor,
271
  router_weight: torch.Tensor,
272
+ router_bias: torch.Tensor,
273
  moe_top_k: int,
274
  moe_num_experts: int,
275
  moe_jitter_eps: float = None,
 
281
  x = apply_jitter(x, moe_jitter_eps)
282
 
283
  x_flat = x.view(-1, x.shape[-1])
284
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
285
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
286
  expert_weights = expert_weights.softmax(dim=-1)
287
  if moe_normalize_expert_weights is not None:
 
319
  w2_bias: torch.Tensor,
320
  gradient_scale: Optional[float] = None,
321
  alpha: float = 1.702,
322
+ limit: float = 7.0,
323
  ):
324
  # Scale weights
325
  w1 = scale_grad(w1, gradient_scale)
 
335
 
336
  # Forward pass
337
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
338
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
339
+ gate = gate.clamp(min=None, max=limit)
340
+ up = up.clamp(min=-limit, max=limit)
341
  glu = gate * torch.sigmoid(gate * alpha)
342
+ next_states = torch.bmm(((up + 1) * glu), w2)
343
+ next_states += w2_bias[..., None, :]
344
+ return next_states
 
345
 
346
  # Shared expert MLP forward pass
347
  def shared_mlp_forward(
 
375
 
376
  # Up projection
377
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
378
+
379
  # Activation
380
  x = activation_fn(x)
381
+
382
  # Down projection
383
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
384
+
385
  return x
386
 
387
 
 
848
  def moe_forward(
849
  x: torch.Tensor,
850
  router_weight: torch.Tensor,
851
+ router_bias: Optional[torch.Tensor],
852
  moe_top_k: int,
853
  moe_num_experts: int,
854
  moe_jitter_eps: float = None,
 
874
  logits, expert_weights, expert_indices = route_tokens(
875
  x,
876
  router_weight,
877
+ router_bias,
878
  moe_top_k,
879
  moe_num_experts,
880
  moe_jitter_eps,
 
936
  def moe_forward_with_shared_expert(
937
  x: torch.Tensor,
938
  router_weight: torch.Tensor,
939
+ router_bias: Optional[torch.Tensor],
940
  moe_top_k: int,
941
  moe_num_experts: int,
942
  moe_jitter_eps: float = None,
 
969
  expert_out, expert_weights, router_scores = moe_forward(
970
  x=x,
971
  router_weight=router_weight,
972
+ router_bias=router_bias,
973
  moe_top_k=moe_top_k,
974
  moe_num_experts=moe_num_experts,
975
  moe_jitter_eps=moe_jitter_eps,
 
990
  hidden_size=hidden_size,
991
  mlp_impl=mlp_impl,
992
  )
993
+
994
  # If shared expert weights provided, compute shared expert output
995
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
996
  shared_expert_out = shared_mlp_forward(
 
1002
  activation_fn=shared_activation_fn,
1003
  gradient_scale=gradient_scale,
1004
  )
1005
+
1006
  # Combine expert outputs
1007
  combined_out = combine_expert_shared_outputs(
1008
  shared_expert_out=shared_expert_out,
 
1010
  shared_expert_weighted_sum=shared_expert_weighted_sum,
1011
  moe_top_k=moe_top_k,
1012
  )
1013
+
1014
  return combined_out, expert_weights, router_scores
1015
+
1016
  # Return regular MoE output if no shared expert
1017
  return expert_out, expert_weights, router_scores
1018
 
 
1028
 
1029
  if output_layer_init_method is None:
1030
  output_layer_init_method = init_method
1031
+
1032
  # Create weight tensors
1033
  up_proj_weight = torch.empty(
1034
  shared_expert_hidden_size,
 
1042
  device=device,
1043
  dtype=dtype,
1044
  )
1045
+
1046
  # Initialize weights
1047
  init_method(up_proj_weight)
1048
  output_layer_init_method(down_proj_weight)
1049
+
1050
  # No bias by default
1051
  return up_proj_weight, down_proj_weight, None, None
1052
 
1053
+
1054
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
1055
  # This exists because device_mesh is trapped in hook closures with no model attribute
1056
  # Fragile - breaks if hook structure changes or Python internals change
 
1059
  # Extract device_mesh from child's unused pre_hook closure
1060
  try:
1061
  # Find the pre-hook that contains 'device_mesh' in its closure
1062
+ hook = next(
1063
+ h
1064
+ for h in model.experts._forward_pre_hooks.values()
1065
+ if "device_mesh" in h.__code__.co_freevars
1066
+ )
1067
  # Extract the device_mesh from the closure
1068
+ return hook.__closure__[
1069
+ hook.__code__.co_freevars.index("device_mesh")
1070
+ ].cell_contents
1071
  except Exception:
1072
  return None
1073
 
1074
 
1075
  class MegaBlocksMoeMLP(torch.nn.Module):
1076
+ can_torch_compile: bool = True
1077
 
1078
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1079
  moe_top_k = getattr(self.router, "top_k", 4)
 
1082
  alpha = getattr(self.experts, "alpha", 1.0)
1083
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1084
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1085
+ moe_normalize_expert_weights = getattr(
1086
+ self.experts, "normalize_expert_weights", None
1087
+ )
1088
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1089
 
1090
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1092
  device_mesh = get_device_mesh(self)
1093
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1094
 
1095
+ has_parallel = (
1096
+ expert_parallel_group is not None
1097
+ and dist.is_initialized()
1098
+ and dist.get_world_size(expert_parallel_group) > 1
1099
+ )
1100
  forward_fn = parallel_forward_once if has_parallel else forward_once
1101
+
1102
+ sort_end_bit = max(
1103
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1104
+ )
1105
  mlp_impl = getattr(self, "mlp_impl", "grouped")
 
1106
  output, expert_weights_out, *_ = moe_forward(
1107
  x=x,
1108
  router_weight=self.router.weight,
1109
+ router_bias=self.router.bias,
1110
  moe_top_k=moe_top_k,
1111
  moe_num_experts=moe_num_experts,
1112
  moe_jitter_eps=moe_jitter_eps,
 
1130
  return output, expert_weights_out
1131
 
1132
 
1133
+ # Export main classes
1134
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
1135
+
1136
+
1137
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
1138
+
1139
  def __init__(self):
1140
  super().__init__()
1141
  # Shared expert weights will be set by the user
 
1145
  self.shared_down_proj_bias = None
1146
  self.shared_expert_weighted_sum = False
1147
  self.shared_activation_fn = None
1148
+
1149
  def set_shared_expert_weights(
1150
  self,
1151
  up_proj_weight: torch.Tensor,
 
1161
  self.shared_down_proj_bias = down_proj_bias
1162
  self.shared_expert_weighted_sum = weighted_sum
1163
  self.shared_activation_fn = activation_fn
1164
+
1165
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1166
  moe_top_k = getattr(self.router, "top_k", 4)
1167
  moe_num_experts = getattr(self.experts, "num_experts", 128)
 
1169
  alpha = getattr(self.experts, "alpha", 1.0)
1170
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1171
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1172
+ moe_normalize_expert_weights = getattr(
1173
+ self.experts, "normalize_expert_weights", None
1174
+ )
1175
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1176
 
1177
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1179
  device_mesh = get_device_mesh(self)
1180
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1181
 
1182
+ has_parallel = (
1183
+ expert_parallel_group is not None
1184
+ and dist.is_initialized()
1185
+ and dist.get_world_size(expert_parallel_group) > 1
1186
+ )
1187
  forward_fn = parallel_forward_once if has_parallel else forward_once
1188
+
1189
+ sort_end_bit = max(
1190
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1191
+ )
1192
  mlp_impl = getattr(self, "mlp_impl", "grouped")
1193
+
1194
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
1195
  x=x,
1196
  router_weight=self.router.weight,
1197
+ router_bias=self.router.bias,
1198
  moe_top_k=moe_top_k,
1199
  moe_num_experts=moe_num_experts,
1200
  moe_jitter_eps=moe_jitter_eps,
 
1222
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
1223
  shared_activation_fn=self.shared_activation_fn,
1224
  )
1225
+ return output, expert_weights_out
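
A minimal usage sketch of the torch.compile path enabled by the layers.py changes above. This is not part of the commit; the router/experts wiring is assumed to be supplied by the host framework (e.g. transformers), and megablocks is assumed to be installed from this build.

import torch
from megablocks.layers import MegaBlocksMoeMLP

mlp = MegaBlocksMoeMLP()
# ... the host framework attaches mlp.router (with .weight/.bias/.top_k) and mlp.experts here ...

if getattr(mlp, "can_torch_compile", False):
    # While torch.compiler.is_compiling() is True, the patched ops.* wrappers return
    # empty tensors of the right shape instead of calling the custom CUDA kernels,
    # which lets Dynamo trace through the MoE forward pass.
    mlp = torch.compile(mlp)
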
build/torch28-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so β†’ _megablocks_8176cbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:289b7b6ebae763cc023856ce1f7179a8c59d6000108311f28d54758a2d3275ad
3
  size 11817960
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:534a56ee5f5d1e8c1691a9644dcb42d54cbd8c41f2f29d13ff01674ef50661a7
3
  size 11817960
build/torch28-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_3bdb4b8_dirty
3
- ops = torch.ops._megablocks_3bdb4b8_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_3bdb4b8_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_8176cbe_dirty
3
+ ops = torch.ops._megablocks_8176cbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_8176cbe_dirty::{op_name}"
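
A small illustrative check (assuming the rebuilt wheel is importable) of what the namespace bump means for op registration: add_op_namespace_prefix now emits the new build hash, so anything that registers fake/meta kernels by qualified name must target the new namespace. The op name "sort" below is only an example string; the helper is pure string formatting.

from megablocks._ops import add_op_namespace_prefix

qualified = add_op_namespace_prefix("sort")
assert qualified == "_megablocks_8176cbe_dirty::sort"
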
build/torch28-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -1,11 +1,200 @@
1
  import torch
2
  import torch.distributed as dist
3
 
4
- from typing import Optional, Any
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
 
10
  # Set the expert model parallel attributes on a tensor
11
  def set_expert_model_parallel_attributes(
@@ -80,6 +269,7 @@ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
80
  def route_tokens(
81
  x: torch.Tensor,
82
  router_weight: torch.Tensor,
 
83
  moe_top_k: int,
84
  moe_num_experts: int,
85
  moe_jitter_eps: float = None,
@@ -91,7 +281,7 @@ def route_tokens(
91
  x = apply_jitter(x, moe_jitter_eps)
92
 
93
  x_flat = x.view(-1, x.shape[-1])
94
- logits = torch.nn.functional.linear(x_flat, router_weight)
95
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
96
  expert_weights = expert_weights.softmax(dim=-1)
97
  if moe_normalize_expert_weights is not None:
@@ -129,6 +319,7 @@ def mlp_forward(
129
  w2_bias: torch.Tensor,
130
  gradient_scale: Optional[float] = None,
131
  alpha: float = 1.702,
 
132
  ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
@@ -144,13 +335,13 @@ def mlp_forward(
144
 
145
  # Forward pass
146
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
147
- gate, up = gate_up.chunk(2, dim=-1)
148
-
 
149
  glu = gate * torch.sigmoid(gate * alpha)
150
- x = (up + 1) * glu
151
-
152
- return torch.bmm(x, w2) + w2_bias[..., None, :]
153
-
154
 
155
  # Shared expert MLP forward pass
156
  def shared_mlp_forward(
@@ -184,13 +375,13 @@ def shared_mlp_forward(
184
 
185
  # Up projection
186
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
-
188
  # Activation
189
  x = activation_fn(x)
190
-
191
  # Down projection
192
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
-
194
  return x
195
 
196
 
@@ -657,6 +848,7 @@ def parallel_forward_once(
657
  def moe_forward(
658
  x: torch.Tensor,
659
  router_weight: torch.Tensor,
 
660
  moe_top_k: int,
661
  moe_num_experts: int,
662
  moe_jitter_eps: float = None,
@@ -682,6 +874,7 @@ def moe_forward(
682
  logits, expert_weights, expert_indices = route_tokens(
683
  x,
684
  router_weight,
 
685
  moe_top_k,
686
  moe_num_experts,
687
  moe_jitter_eps,
@@ -743,6 +936,7 @@ def moe_forward(
743
  def moe_forward_with_shared_expert(
744
  x: torch.Tensor,
745
  router_weight: torch.Tensor,
 
746
  moe_top_k: int,
747
  moe_num_experts: int,
748
  moe_jitter_eps: float = None,
@@ -775,6 +969,7 @@ def moe_forward_with_shared_expert(
775
  expert_out, expert_weights, router_scores = moe_forward(
776
  x=x,
777
  router_weight=router_weight,
 
778
  moe_top_k=moe_top_k,
779
  moe_num_experts=moe_num_experts,
780
  moe_jitter_eps=moe_jitter_eps,
@@ -795,7 +990,7 @@ def moe_forward_with_shared_expert(
795
  hidden_size=hidden_size,
796
  mlp_impl=mlp_impl,
797
  )
798
-
799
  # If shared expert weights provided, compute shared expert output
800
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
  shared_expert_out = shared_mlp_forward(
@@ -807,7 +1002,7 @@ def moe_forward_with_shared_expert(
807
  activation_fn=shared_activation_fn,
808
  gradient_scale=gradient_scale,
809
  )
810
-
811
  # Combine expert outputs
812
  combined_out = combine_expert_shared_outputs(
813
  shared_expert_out=shared_expert_out,
@@ -815,9 +1010,9 @@ def moe_forward_with_shared_expert(
815
  shared_expert_weighted_sum=shared_expert_weighted_sum,
816
  moe_top_k=moe_top_k,
817
  )
818
-
819
  return combined_out, expert_weights, router_scores
820
-
821
  # Return regular MoE output if no shared expert
822
  return expert_out, expert_weights, router_scores
823
 
@@ -833,7 +1028,7 @@ def create_shared_expert_weights(
833
 
834
  if output_layer_init_method is None:
835
  output_layer_init_method = init_method
836
-
837
  # Create weight tensors
838
  up_proj_weight = torch.empty(
839
  shared_expert_hidden_size,
@@ -847,14 +1042,15 @@ def create_shared_expert_weights(
847
  device=device,
848
  dtype=dtype,
849
  )
850
-
851
  # Initialize weights
852
  init_method(up_proj_weight)
853
  output_layer_init_method(down_proj_weight)
854
-
855
  # No bias by default
856
  return up_proj_weight, down_proj_weight, None, None
857
 
 
858
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
  # This exists because device_mesh is trapped in hook closures with no model attribute
860
  # Fragile - breaks if hook structure changes or Python internals change
@@ -863,14 +1059,21 @@ def get_device_mesh(model):
863
  # Extract device_mesh from child's unused pre_hook closure
864
  try:
865
  # Find the pre-hook that contains 'device_mesh' in its closure
866
- hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
 
 
 
 
867
  # Extract the device_mesh from the closure
868
- return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
 
 
869
  except Exception:
870
  return None
871
 
872
 
873
  class MegaBlocksMoeMLP(torch.nn.Module):
 
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
876
  moe_top_k = getattr(self.router, "top_k", 4)
@@ -879,7 +1082,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
879
  alpha = getattr(self.experts, "alpha", 1.0)
880
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
 
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -887,15 +1092,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
887
  device_mesh = get_device_mesh(self)
888
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
 
890
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
-
893
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
894
  mlp_impl = getattr(self, "mlp_impl", "grouped")
895
-
896
  output, expert_weights_out, *_ = moe_forward(
897
  x=x,
898
  router_weight=self.router.weight,
 
899
  moe_top_k=moe_top_k,
900
  moe_num_experts=moe_num_experts,
901
  moe_jitter_eps=moe_jitter_eps,
@@ -919,8 +1130,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
919
  return output, expert_weights_out
920
 
921
 
 
 
 
 
922
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
-
924
  def __init__(self):
925
  super().__init__()
926
  # Shared expert weights will be set by the user
@@ -930,7 +1145,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
930
  self.shared_down_proj_bias = None
931
  self.shared_expert_weighted_sum = False
932
  self.shared_activation_fn = None
933
-
934
  def set_shared_expert_weights(
935
  self,
936
  up_proj_weight: torch.Tensor,
@@ -946,7 +1161,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
946
  self.shared_down_proj_bias = down_proj_bias
947
  self.shared_expert_weighted_sum = weighted_sum
948
  self.shared_activation_fn = activation_fn
949
-
950
  def forward(self, x: torch.Tensor) -> torch.Tensor:
951
  moe_top_k = getattr(self.router, "top_k", 4)
952
  moe_num_experts = getattr(self.experts, "num_experts", 128)
@@ -954,7 +1169,9 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
954
  alpha = getattr(self.experts, "alpha", 1.0)
955
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
958
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
 
960
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -962,15 +1179,22 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
962
  device_mesh = get_device_mesh(self)
963
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
 
965
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
966
  forward_fn = parallel_forward_once if has_parallel else forward_once
967
-
968
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
969
  mlp_impl = getattr(self, "mlp_impl", "grouped")
970
-
971
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
  x=x,
973
  router_weight=self.router.weight,
 
974
  moe_top_k=moe_top_k,
975
  moe_num_experts=moe_num_experts,
976
  moe_jitter_eps=moe_jitter_eps,
@@ -998,4 +1222,4 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
998
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
  shared_activation_fn=self.shared_activation_fn,
1000
  )
1001
- return output, expert_weights_out
 
1
  import torch
2
  import torch.distributed as dist
3
 
4
+ from typing import Optional, Any, TYPE_CHECKING
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
+ # Conditional import for meta kernel registration
10
+ if TYPE_CHECKING:
11
+
12
+ def register_fake(fn):
13
+ return lambda name: fn
14
+
15
+ else:
16
+ try:
17
+ from torch.library import register_fake
18
+ except ImportError:
19
+ try:
20
+ from torch.library import impl_abstract as register_fake
21
+ except ImportError:
22
+ # Fallback for older PyTorch versions
23
+ def register_fake(op_name):
24
+ def decorator(fn):
25
+ return fn
26
+
27
+ return decorator
28
+
29
+
30
+ # Meta kernel implementations for torch.compile compatibility
31
+ def _install_meta_kernels():
32
+ """Install meta kernels for existing MegaBlocks operations"""
33
+
34
+ # Create wrapper functions that check for compilation and return meta tensors
35
+
36
+ # Patch ops.sort
37
+ if hasattr(ops, "sort"):
38
+ original_sort = ops.sort
39
+
40
+ def sort_with_meta(x, end_bit=None):
41
+ if torch.compiler.is_compiling():
42
+ print("Using meta kernel for sort")
43
+ # Meta implementation - return tensors with correct shape/dtype/device
44
+ return torch.empty_like(x), torch.empty_like(x)
45
+ # print("Using original sort kernel")
46
+ return original_sort(x, end_bit)
47
+
48
+ ops.sort = sort_with_meta
49
+
50
+ # Patch ops.histogram
51
+ if hasattr(ops, "histogram"):
52
+ original_histogram = ops.histogram
53
+
54
+ def histogram_with_meta(x, max_val):
55
+ if torch.compiler.is_compiling():
56
+ # Meta implementation
57
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
58
+ return original_histogram(x, max_val)
59
+
60
+ ops.histogram = histogram_with_meta
61
+
62
+ # Patch ops.inclusive_cumsum
63
+ if hasattr(ops, "inclusive_cumsum"):
64
+ original_inclusive_cumsum = ops.inclusive_cumsum
65
+
66
+ def inclusive_cumsum_with_meta(x, dim):
67
+ if torch.compiler.is_compiling():
68
+ # Meta implementation
69
+ return torch.empty_like(x)
70
+ return original_inclusive_cumsum(x, dim)
71
+
72
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
73
+
74
+ # Patch ops.binned_gather
75
+ if hasattr(ops, "binned_gather"):
76
+ original_binned_gather = ops.binned_gather
77
+
78
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
79
+ if torch.compiler.is_compiling():
80
+ # Meta implementation - output shape based on bin_size
81
+ if x.dim() >= 2:
82
+ hidden_size = x.size(-1)
83
+ return torch.empty(
84
+ (bin_size, x.size(1), hidden_size),
85
+ dtype=x.dtype,
86
+ device=x.device,
87
+ )
88
+ else:
89
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
90
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
91
+
92
+ ops.binned_gather = binned_gather_with_meta
93
+
94
+ # Patch ops.binned_scatter
95
+ if hasattr(ops, "binned_scatter"):
96
+ original_binned_scatter = ops.binned_scatter
97
+
98
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
99
+ if torch.compiler.is_compiling():
100
+ # Meta implementation - typically reduces to 2D
101
+ if x.dim() >= 3:
102
+ return torch.empty(
103
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
104
+ )
105
+ else:
106
+ return torch.empty_like(x)
107
+ return original_binned_scatter(x, indices, weights, bins, top_k)
108
+
109
+ ops.binned_scatter = binned_scatter_with_meta
110
+
111
+ # Patch ops.gather
112
+ if hasattr(ops, "gather"):
113
+ original_gather = ops.gather
114
+
115
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
116
+ if torch.compiler.is_compiling():
117
+ # Meta implementation
118
+ if x.dim() >= 2:
119
+ hidden_size = x.size(-1)
120
+ return torch.empty(
121
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
122
+ )
123
+ else:
124
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
125
+ return original_gather(x, indices, bin_ids, bins, top_k)
126
+
127
+ ops.gather = gather_with_meta
128
+
129
+ # Patch ops.scatter
130
+ if hasattr(ops, "scatter"):
131
+ original_scatter = ops.scatter
132
+
133
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
134
+ if torch.compiler.is_compiling():
135
+ # Meta implementation - restore sequence shape
136
+ seq_len = (
137
+ indices.size(0) // top_k
138
+ if indices.numel() > 0 and top_k > 0
139
+ else x.size(0)
140
+ )
141
+ if x.dim() >= 2:
142
+ return torch.empty(
143
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
144
+ )
145
+ else:
146
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
147
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
148
+
149
+ ops.scatter = scatter_with_meta
150
+
151
+ # Patch ops.replicate
152
+ if hasattr(ops, "replicate"):
153
+ original_replicate = ops.replicate
154
+
155
+ def replicate_with_meta(x, bins, num_outputs):
156
+ if torch.compiler.is_compiling():
157
+ # Meta implementation
158
+ return torch.empty(
159
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
160
+ )
161
+ return original_replicate(x, bins, num_outputs)
162
+
163
+ ops.replicate = replicate_with_meta
164
+
165
+ # Patch ops.repeat (if it's a regular function)
166
+ if hasattr(ops, "repeat"):
167
+ original_repeat = ops.repeat
168
+
169
+ def repeat_with_meta(x, repeats):
170
+ if torch.compiler.is_compiling():
171
+ # Meta implementation
172
+ if isinstance(repeats, (tuple, list)):
173
+ new_shape = list(x.shape)
174
+ for i, rep in enumerate(repeats):
175
+ if i < len(new_shape):
176
+ new_shape[i] *= rep
177
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
178
+ else:
179
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
180
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
181
+ return original_repeat(x, repeats)
182
+
183
+ ops.repeat = repeat_with_meta
184
+
185
+
186
+ # Install meta kernels on import
187
+ try:
188
+ _install_meta_kernels()
189
+ except Exception as e:
190
+ # If meta kernel installation fails, continue without them
191
+ # torch.compile may not work but the library will still function
192
+ import warnings
193
+
194
+ warnings.warn(
195
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
196
+ )
197
+
198
 
199
  # Set the expert model parallel attributes on a tensor
200
  def set_expert_model_parallel_attributes(
 
269
  def route_tokens(
270
  x: torch.Tensor,
271
  router_weight: torch.Tensor,
272
+ router_bias: torch.Tensor,
273
  moe_top_k: int,
274
  moe_num_experts: int,
275
  moe_jitter_eps: float = None,
 
281
  x = apply_jitter(x, moe_jitter_eps)
282
 
283
  x_flat = x.view(-1, x.shape[-1])
284
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
285
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
286
  expert_weights = expert_weights.softmax(dim=-1)
287
  if moe_normalize_expert_weights is not None:
 
319
  w2_bias: torch.Tensor,
320
  gradient_scale: Optional[float] = None,
321
  alpha: float = 1.702,
322
+ limit: float = 7.0,
323
  ):
324
  # Scale weights
325
  w1 = scale_grad(w1, gradient_scale)
 
335
 
336
  # Forward pass
337
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
338
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
339
+ gate = gate.clamp(min=None, max=limit)
340
+ up = up.clamp(min=-limit, max=limit)
341
  glu = gate * torch.sigmoid(gate * alpha)
342
+ next_states = torch.bmm(((up + 1) * glu), w2)
343
+ next_states += w2_bias[..., None, :]
344
+ return next_states
 
345
 
346
  # Shared expert MLP forward pass
347
  def shared_mlp_forward(
 
375
 
376
  # Up projection
377
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
378
+
379
  # Activation
380
  x = activation_fn(x)
381
+
382
  # Down projection
383
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
384
+
385
  return x
386
 
387
 
 
848
  def moe_forward(
849
  x: torch.Tensor,
850
  router_weight: torch.Tensor,
851
+ router_bias: Optional[torch.Tensor],
852
  moe_top_k: int,
853
  moe_num_experts: int,
854
  moe_jitter_eps: float = None,
 
874
  logits, expert_weights, expert_indices = route_tokens(
875
  x,
876
  router_weight,
877
+ router_bias,
878
  moe_top_k,
879
  moe_num_experts,
880
  moe_jitter_eps,
 
936
  def moe_forward_with_shared_expert(
937
  x: torch.Tensor,
938
  router_weight: torch.Tensor,
939
+ router_bias: Optional[torch.Tensor],
940
  moe_top_k: int,
941
  moe_num_experts: int,
942
  moe_jitter_eps: float = None,
 
969
  expert_out, expert_weights, router_scores = moe_forward(
970
  x=x,
971
  router_weight=router_weight,
972
+ router_bias=router_bias,
973
  moe_top_k=moe_top_k,
974
  moe_num_experts=moe_num_experts,
975
  moe_jitter_eps=moe_jitter_eps,
 
990
  hidden_size=hidden_size,
991
  mlp_impl=mlp_impl,
992
  )
993
+
994
  # If shared expert weights provided, compute shared expert output
995
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
996
  shared_expert_out = shared_mlp_forward(
 
1002
  activation_fn=shared_activation_fn,
1003
  gradient_scale=gradient_scale,
1004
  )
1005
+
1006
  # Combine expert outputs
1007
  combined_out = combine_expert_shared_outputs(
1008
  shared_expert_out=shared_expert_out,
 
1010
  shared_expert_weighted_sum=shared_expert_weighted_sum,
1011
  moe_top_k=moe_top_k,
1012
  )
1013
+
1014
  return combined_out, expert_weights, router_scores
1015
+
1016
  # Return regular MoE output if no shared expert
1017
  return expert_out, expert_weights, router_scores
1018
 
 
1028
 
1029
  if output_layer_init_method is None:
1030
  output_layer_init_method = init_method
1031
+
1032
  # Create weight tensors
1033
  up_proj_weight = torch.empty(
1034
  shared_expert_hidden_size,
 
1042
  device=device,
1043
  dtype=dtype,
1044
  )
1045
+
1046
  # Initialize weights
1047
  init_method(up_proj_weight)
1048
  output_layer_init_method(down_proj_weight)
1049
+
1050
  # No bias by default
1051
  return up_proj_weight, down_proj_weight, None, None
1052
 
1053
+
1054
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
1055
  # This exists because device_mesh is trapped in hook closures with no model attribute
1056
  # Fragile - breaks if hook structure changes or Python internals change
 
1059
  # Extract device_mesh from child's unused pre_hook closure
1060
  try:
1061
  # Find the pre-hook that contains 'device_mesh' in its closure
1062
+ hook = next(
1063
+ h
1064
+ for h in model.experts._forward_pre_hooks.values()
1065
+ if "device_mesh" in h.__code__.co_freevars
1066
+ )
1067
  # Extract the device_mesh from the closure
1068
+ return hook.__closure__[
1069
+ hook.__code__.co_freevars.index("device_mesh")
1070
+ ].cell_contents
1071
  except Exception:
1072
  return None
1073
 
1074
 
1075
  class MegaBlocksMoeMLP(torch.nn.Module):
1076
+ can_torch_compile: bool = True
1077
 
1078
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1079
  moe_top_k = getattr(self.router, "top_k", 4)
 
1082
  alpha = getattr(self.experts, "alpha", 1.0)
1083
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1084
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1085
+ moe_normalize_expert_weights = getattr(
1086
+ self.experts, "normalize_expert_weights", None
1087
+ )
1088
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1089
 
1090
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1092
  device_mesh = get_device_mesh(self)
1093
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1094
 
1095
+ has_parallel = (
1096
+ expert_parallel_group is not None
1097
+ and dist.is_initialized()
1098
+ and dist.get_world_size(expert_parallel_group) > 1
1099
+ )
1100
  forward_fn = parallel_forward_once if has_parallel else forward_once
1101
+
1102
+ sort_end_bit = max(
1103
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1104
+ )
1105
  mlp_impl = getattr(self, "mlp_impl", "grouped")
 
1106
  output, expert_weights_out, *_ = moe_forward(
1107
  x=x,
1108
  router_weight=self.router.weight,
1109
+ router_bias=self.router.bias,
1110
  moe_top_k=moe_top_k,
1111
  moe_num_experts=moe_num_experts,
1112
  moe_jitter_eps=moe_jitter_eps,
 
1130
  return output, expert_weights_out
1131
 
1132
 
1133
+ # Export main classes
1134
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
1135
+
1136
+
1137
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
1138
+
1139
  def __init__(self):
1140
  super().__init__()
1141
  # Shared expert weights will be set by the user
 
1145
  self.shared_down_proj_bias = None
1146
  self.shared_expert_weighted_sum = False
1147
  self.shared_activation_fn = None
1148
+
1149
  def set_shared_expert_weights(
1150
  self,
1151
  up_proj_weight: torch.Tensor,
 
1161
  self.shared_down_proj_bias = down_proj_bias
1162
  self.shared_expert_weighted_sum = weighted_sum
1163
  self.shared_activation_fn = activation_fn
1164
+
1165
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1166
  moe_top_k = getattr(self.router, "top_k", 4)
1167
  moe_num_experts = getattr(self.experts, "num_experts", 128)
 
1169
  alpha = getattr(self.experts, "alpha", 1.0)
1170
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1171
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1172
+ moe_normalize_expert_weights = getattr(
1173
+ self.experts, "normalize_expert_weights", None
1174
+ )
1175
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1176
 
1177
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1179
  device_mesh = get_device_mesh(self)
1180
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1181
 
1182
+ has_parallel = (
1183
+ expert_parallel_group is not None
1184
+ and dist.is_initialized()
1185
+ and dist.get_world_size(expert_parallel_group) > 1
1186
+ )
1187
  forward_fn = parallel_forward_once if has_parallel else forward_once
1188
+
1189
+ sort_end_bit = max(
1190
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1191
+ )
1192
  mlp_impl = getattr(self, "mlp_impl", "grouped")
1193
+
1194
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
1195
  x=x,
1196
  router_weight=self.router.weight,
1197
+ router_bias=self.router.bias,
1198
  moe_top_k=moe_top_k,
1199
  moe_num_experts=moe_num_experts,
1200
  moe_jitter_eps=moe_jitter_eps,
 
1222
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
1223
  shared_activation_fn=self.shared_activation_fn,
1224
  )
1225
+ return output, expert_weights_out
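
A standalone sketch of the updated expert activation introduced in mlp_forward above, pulled out of the batched bmm/expert layout so the interleaved gate/up split and the clamping are easy to inspect. Shapes and values are illustrative only.

import torch

def gated_activation(gate_up: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Gate/up are interleaved along the last dimension (::2 / 1::2), not concatenated halves.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)            # gate is clamped from above only
    up = up.clamp(min=-limit, max=limit)    # up is clamped on both sides
    glu = gate * torch.sigmoid(gate * alpha)
    return (up + 1) * glu                   # fed into the down projection (w2) in mlp_forward

tokens = torch.randn(4, 8)                  # 4 toy tokens, 2 * ffn_dim = 8
print(gated_activation(tokens).shape)       # torch.Size([4, 4])
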
build/torch28-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so β†’ _megablocks_8176cbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:af3fadbfa9afb6d0315a5dcf57d4a2dbc043e2529dab9d3a6d0656fac4142211
3
  size 17770912
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c075d3398cea481296a0a8b47444fc49a4c984dac991abf5bd3577bc9edf1b71
3
  size 17770912
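
A worked example of the sort_end_bit expression that the layers.py diffs in this commit reformat: the number of bits radix sort needs to cover the expert-index range, clamped to at least one bit. The plain helper below mirrors that expression for a few expert counts.

import torch

def sort_end_bit(moe_num_experts: int) -> int:
    return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

assert sort_end_bit(1) == 1      # clamped to a minimum of one bit
assert sort_end_bit(64) == 6
assert sort_end_bit(128) == 7    # 128 is the default num_experts in these layers
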
build/torch28-cxx11-cu128-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_3bdb4b8_dirty
3
- ops = torch.ops._megablocks_3bdb4b8_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_3bdb4b8_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_8176cbe_dirty
3
+ ops = torch.ops._megablocks_8176cbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_8176cbe_dirty::{op_name}"
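
A minimal illustration of the closure trick that get_device_mesh in layers.py relies on (simplified; the real code walks model.experts._forward_pre_hooks). A value captured by a hook shows up in __code__.co_freevars and its cell can be read back from __closure__, which is why the comment in the diff warns that the approach is fragile.

def make_pre_hook(device_mesh):
    def hook(module, args):
        _ = device_mesh  # captured, so "device_mesh" appears in hook.__code__.co_freevars
    return hook

hook = make_pre_hook("toy-mesh")
idx = hook.__code__.co_freevars.index("device_mesh")
assert hook.__closure__[idx].cell_contents == "toy-mesh"
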
build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py CHANGED
@@ -1,11 +1,200 @@
1
  import torch
2
  import torch.distributed as dist
3
 
4
- from typing import Optional, Any
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
 
10
  # Set the expert model parallel attributes on a tensor
11
  def set_expert_model_parallel_attributes(
@@ -80,6 +269,7 @@ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
80
  def route_tokens(
81
  x: torch.Tensor,
82
  router_weight: torch.Tensor,
 
83
  moe_top_k: int,
84
  moe_num_experts: int,
85
  moe_jitter_eps: float = None,
@@ -91,7 +281,7 @@ def route_tokens(
91
  x = apply_jitter(x, moe_jitter_eps)
92
 
93
  x_flat = x.view(-1, x.shape[-1])
94
- logits = torch.nn.functional.linear(x_flat, router_weight)
95
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
96
  expert_weights = expert_weights.softmax(dim=-1)
97
  if moe_normalize_expert_weights is not None:
@@ -129,6 +319,7 @@ def mlp_forward(
129
  w2_bias: torch.Tensor,
130
  gradient_scale: Optional[float] = None,
131
  alpha: float = 1.702,
 
132
  ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
@@ -144,13 +335,13 @@ def mlp_forward(
144
 
145
  # Forward pass
146
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
147
- gate, up = gate_up.chunk(2, dim=-1)
148
-
 
149
  glu = gate * torch.sigmoid(gate * alpha)
150
- x = (up + 1) * glu
151
-
152
- return torch.bmm(x, w2) + w2_bias[..., None, :]
153
-
154
 
155
  # Shared expert MLP forward pass
156
  def shared_mlp_forward(
@@ -184,13 +375,13 @@ def shared_mlp_forward(
184
 
185
  # Up projection
186
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
-
188
  # Activation
189
  x = activation_fn(x)
190
-
191
  # Down projection
192
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
-
194
  return x
195
 
196
 
@@ -657,6 +848,7 @@ def parallel_forward_once(
657
  def moe_forward(
658
  x: torch.Tensor,
659
  router_weight: torch.Tensor,
 
660
  moe_top_k: int,
661
  moe_num_experts: int,
662
  moe_jitter_eps: float = None,
@@ -682,6 +874,7 @@ def moe_forward(
682
  logits, expert_weights, expert_indices = route_tokens(
683
  x,
684
  router_weight,
 
685
  moe_top_k,
686
  moe_num_experts,
687
  moe_jitter_eps,
@@ -743,6 +936,7 @@ def moe_forward(
743
  def moe_forward_with_shared_expert(
744
  x: torch.Tensor,
745
  router_weight: torch.Tensor,
 
746
  moe_top_k: int,
747
  moe_num_experts: int,
748
  moe_jitter_eps: float = None,
@@ -775,6 +969,7 @@ def moe_forward_with_shared_expert(
775
  expert_out, expert_weights, router_scores = moe_forward(
776
  x=x,
777
  router_weight=router_weight,
 
778
  moe_top_k=moe_top_k,
779
  moe_num_experts=moe_num_experts,
780
  moe_jitter_eps=moe_jitter_eps,
@@ -795,7 +990,7 @@ def moe_forward_with_shared_expert(
795
  hidden_size=hidden_size,
796
  mlp_impl=mlp_impl,
797
  )
798
-
799
  # If shared expert weights provided, compute shared expert output
800
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
  shared_expert_out = shared_mlp_forward(
@@ -807,7 +1002,7 @@ def moe_forward_with_shared_expert(
807
  activation_fn=shared_activation_fn,
808
  gradient_scale=gradient_scale,
809
  )
810
-
811
  # Combine expert outputs
812
  combined_out = combine_expert_shared_outputs(
813
  shared_expert_out=shared_expert_out,
@@ -815,9 +1010,9 @@ def moe_forward_with_shared_expert(
815
  shared_expert_weighted_sum=shared_expert_weighted_sum,
816
  moe_top_k=moe_top_k,
817
  )
818
-
819
  return combined_out, expert_weights, router_scores
820
-
821
  # Return regular MoE output if no shared expert
822
  return expert_out, expert_weights, router_scores
823
 
@@ -833,7 +1028,7 @@ def create_shared_expert_weights(
833
 
834
  if output_layer_init_method is None:
835
  output_layer_init_method = init_method
836
-
837
  # Create weight tensors
838
  up_proj_weight = torch.empty(
839
  shared_expert_hidden_size,
@@ -847,14 +1042,15 @@ def create_shared_expert_weights(
847
  device=device,
848
  dtype=dtype,
849
  )
850
-
851
  # Initialize weights
852
  init_method(up_proj_weight)
853
  output_layer_init_method(down_proj_weight)
854
-
855
  # No bias by default
856
  return up_proj_weight, down_proj_weight, None, None
857
 
 
858
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
  # This exists because device_mesh is trapped in hook closures with no model attribute
860
  # Fragile - breaks if hook structure changes or Python internals change
@@ -863,14 +1059,21 @@ def get_device_mesh(model):
863
  # Extract device_mesh from child's unused pre_hook closure
864
  try:
865
  # Find the pre-hook that contains 'device_mesh' in its closure
866
- hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
 
 
 
 
867
  # Extract the device_mesh from the closure
868
- return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
 
 
869
  except Exception:
870
  return None
871
 
872
 
873
  class MegaBlocksMoeMLP(torch.nn.Module):
 
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
876
  moe_top_k = getattr(self.router, "top_k", 4)
@@ -879,7 +1082,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
879
  alpha = getattr(self.experts, "alpha", 1.0)
880
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
 
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -887,15 +1092,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
887
  device_mesh = get_device_mesh(self)
888
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
 
890
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
-
893
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
894
  mlp_impl = getattr(self, "mlp_impl", "grouped")
895
-
896
  output, expert_weights_out, *_ = moe_forward(
897
  x=x,
898
  router_weight=self.router.weight,
 
899
  moe_top_k=moe_top_k,
900
  moe_num_experts=moe_num_experts,
901
  moe_jitter_eps=moe_jitter_eps,
@@ -919,8 +1130,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
919
  return output, expert_weights_out
920
 
921
 
 
 
 
 
922
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
-
924
  def __init__(self):
925
  super().__init__()
926
  # Shared expert weights will be set by the user
@@ -930,7 +1145,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
930
  self.shared_down_proj_bias = None
931
  self.shared_expert_weighted_sum = False
932
  self.shared_activation_fn = None
933
-
934
  def set_shared_expert_weights(
935
  self,
936
  up_proj_weight: torch.Tensor,
@@ -946,7 +1161,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
946
  self.shared_down_proj_bias = down_proj_bias
947
  self.shared_expert_weighted_sum = weighted_sum
948
  self.shared_activation_fn = activation_fn
949
-
950
  def forward(self, x: torch.Tensor) -> torch.Tensor:
951
  moe_top_k = getattr(self.router, "top_k", 4)
952
  moe_num_experts = getattr(self.experts, "num_experts", 128)
@@ -954,7 +1169,9 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
954
  alpha = getattr(self.experts, "alpha", 1.0)
955
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
958
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
 
960
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -962,15 +1179,22 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
962
  device_mesh = get_device_mesh(self)
963
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
 
965
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
966
  forward_fn = parallel_forward_once if has_parallel else forward_once
967
-
968
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
969
  mlp_impl = getattr(self, "mlp_impl", "grouped")
970
-
971
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
  x=x,
973
  router_weight=self.router.weight,
 
974
  moe_top_k=moe_top_k,
975
  moe_num_experts=moe_num_experts,
976
  moe_jitter_eps=moe_jitter_eps,
@@ -998,4 +1222,4 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
998
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
  shared_activation_fn=self.shared_activation_fn,
1000
  )
1001
- return output, expert_weights_out
 
1
  import torch
2
  import torch.distributed as dist
3
 
4
+ from typing import Optional, Any, TYPE_CHECKING
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
+ # Conditional import for meta kernel registration
10
+ if TYPE_CHECKING:
11
+
12
+ def register_fake(fn):
13
+ return lambda name: fn
14
+
15
+ else:
16
+ try:
17
+ from torch.library import register_fake
18
+ except ImportError:
19
+ try:
20
+ from torch.library import impl_abstract as register_fake
21
+ except ImportError:
22
+ # Fallback for older PyTorch versions
23
+ def register_fake(op_name):
24
+ def decorator(fn):
25
+ return fn
26
+
27
+ return decorator
28
+
29
+
30
+ # Meta kernel implementations for torch.compile compatibility
31
+ def _install_meta_kernels():
32
+ """Install meta kernels for existing MegaBlocks operations"""
33
+
34
+ # Create wrapper functions that check for compilation and return meta tensors
35
+
36
+ # Patch ops.sort
37
+ if hasattr(ops, "sort"):
38
+ original_sort = ops.sort
39
+
40
+ def sort_with_meta(x, end_bit=None):
41
+ if torch.compiler.is_compiling():
42
+ print("Using meta kernel for sort")
43
+ # Meta implementation - return tensors with correct shape/dtype/device
44
+ return torch.empty_like(x), torch.empty_like(x)
45
+ # print("Using original sort kernel")
46
+ return original_sort(x, end_bit)
47
+
48
+ ops.sort = sort_with_meta
49
+
50
+ # Patch ops.histogram
51
+ if hasattr(ops, "histogram"):
52
+ original_histogram = ops.histogram
53
+
54
+ def histogram_with_meta(x, max_val):
55
+ if torch.compiler.is_compiling():
56
+ # Meta implementation
57
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
58
+ return original_histogram(x, max_val)
59
+
60
+ ops.histogram = histogram_with_meta
61
+
62
+ # Patch ops.inclusive_cumsum
63
+ if hasattr(ops, "inclusive_cumsum"):
64
+ original_inclusive_cumsum = ops.inclusive_cumsum
65
+
66
+ def inclusive_cumsum_with_meta(x, dim):
67
+ if torch.compiler.is_compiling():
68
+ # Meta implementation
69
+ return torch.empty_like(x)
70
+ return original_inclusive_cumsum(x, dim)
71
+
72
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
73
+
74
+ # Patch ops.binned_gather
75
+ if hasattr(ops, "binned_gather"):
76
+ original_binned_gather = ops.binned_gather
77
+
78
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
79
+ if torch.compiler.is_compiling():
80
+ # Meta implementation - output shape based on bin_size
81
+ if x.dim() >= 2:
82
+ hidden_size = x.size(-1)
83
+ return torch.empty(
84
+ (bin_size, x.size(1), hidden_size),
85
+ dtype=x.dtype,
86
+ device=x.device,
87
+ )
88
+ else:
89
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
90
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
91
+
92
+ ops.binned_gather = binned_gather_with_meta
93
+
94
+ # Patch ops.binned_scatter
95
+ if hasattr(ops, "binned_scatter"):
96
+ original_binned_scatter = ops.binned_scatter
97
+
98
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
99
+ if torch.compiler.is_compiling():
100
+ # Meta implementation - typically reduces to 2D
101
+ if x.dim() >= 3:
102
+ return torch.empty(
103
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
104
+ )
105
+ else:
106
+ return torch.empty_like(x)
107
+ return original_binned_scatter(x, indices, weights, bins, top_k)
108
+
109
+ ops.binned_scatter = binned_scatter_with_meta
110
+
111
+ # Patch ops.gather
112
+ if hasattr(ops, "gather"):
113
+ original_gather = ops.gather
114
+
115
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
116
+ if torch.compiler.is_compiling():
117
+ # Meta implementation
118
+ if x.dim() >= 2:
119
+ hidden_size = x.size(-1)
120
+ return torch.empty(
121
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
122
+ )
123
+ else:
124
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
125
+ return original_gather(x, indices, bin_ids, bins, top_k)
126
+
127
+ ops.gather = gather_with_meta
128
+
129
+ # Patch ops.scatter
130
+ if hasattr(ops, "scatter"):
131
+ original_scatter = ops.scatter
132
+
133
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
134
+ if torch.compiler.is_compiling():
135
+ # Meta implementation - restore sequence shape
136
+ seq_len = (
137
+ indices.size(0) // top_k
138
+ if indices.numel() > 0 and top_k > 0
139
+ else x.size(0)
140
+ )
141
+ if x.dim() >= 2:
142
+ return torch.empty(
143
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
144
+ )
145
+ else:
146
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
147
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
148
+
149
+ ops.scatter = scatter_with_meta
150
+
151
+ # Patch ops.replicate
152
+ if hasattr(ops, "replicate"):
153
+ original_replicate = ops.replicate
154
+
155
+ def replicate_with_meta(x, bins, num_outputs):
156
+ if torch.compiler.is_compiling():
157
+ # Meta implementation
158
+ return torch.empty(
159
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
160
+ )
161
+ return original_replicate(x, bins, num_outputs)
162
+
163
+ ops.replicate = replicate_with_meta
164
+
165
+ # Patch ops.repeat (if it's a regular function)
166
+ if hasattr(ops, "repeat"):
167
+ original_repeat = ops.repeat
168
+
169
+ def repeat_with_meta(x, repeats):
170
+ if torch.compiler.is_compiling():
171
+ # Meta implementation
172
+ if isinstance(repeats, (tuple, list)):
173
+ new_shape = list(x.shape)
174
+ for i, rep in enumerate(repeats):
175
+ if i < len(new_shape):
176
+ new_shape[i] *= rep
177
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
178
+ else:
179
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
180
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
181
+ return original_repeat(x, repeats)
182
+
183
+ ops.repeat = repeat_with_meta
184
+
185
+
186
+ # Install meta kernels on import
187
+ try:
188
+ _install_meta_kernels()
189
+ except Exception as e:
190
+ # If meta kernel installation fails, continue without them
191
+ # torch.compile may not work but the library will still function
192
+ import warnings
193
+
194
+ warnings.warn(
195
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
196
+ )
197
+
198
 
199
  # Set the expert model parallel attributes on a tensor
200
  def set_expert_model_parallel_attributes(
 
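Editor's note (not part of the diff): the _install_meta_kernels() hunk above wraps each ops.* entry point so that, while torch.compile is tracing, the call returns shape-only placeholder tensors instead of launching the real kernel. The same pattern in isolation, as a minimal sketch; the helper name _with_meta and the choice of histogram are illustrative only, and torch.compiler.is_compiling() assumes a reasonably recent PyTorch:

import torch

def _with_meta(original_fn, meta_fn):
    # Dispatch to the shape-only implementation during torch.compile tracing,
    # and to the real kernel otherwise.
    def wrapper(*args, **kwargs):
        if torch.compiler.is_compiling():
            return meta_fn(*args, **kwargs)
        return original_fn(*args, **kwargs)
    return wrapper

def _histogram_meta(x, max_val):
    # Mirrors the real kernel's contract: a length-max_val int32 histogram on x's device.
    return torch.empty((max_val,), dtype=torch.int32, device=x.device)

# Applied in the diff as, e.g.:  ops.histogram = _with_meta(ops.histogram, _histogram_meta)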
269
  def route_tokens(
270
  x: torch.Tensor,
271
  router_weight: torch.Tensor,
272
+ router_bias: torch.Tensor,
273
  moe_top_k: int,
274
  moe_num_experts: int,
275
  moe_jitter_eps: float = None,
 
281
  x = apply_jitter(x, moe_jitter_eps)
282
 
283
  x_flat = x.view(-1, x.shape[-1])
284
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
285
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
286
  expert_weights = expert_weights.softmax(dim=-1)
287
  if moe_normalize_expert_weights is not None:
 
319
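Editor's note (not part of the diff): route_tokens now threads an optional router bias into the logits through torch.nn.functional.linear. A small self-contained sketch of the routing math with toy sizes (5 tokens, hidden size 8, 4 experts, top_k=2; all numbers are illustrative):

import torch
import torch.nn.functional as F

x_flat = torch.randn(5, 8)                 # flattened tokens
router_weight = torch.randn(4, 8)          # one row of routing weights per expert
router_bias = torch.zeros(4)               # the newly threaded-through bias
logits = F.linear(x_flat, router_weight, router_bias)            # shape (5, 4)
expert_weights, expert_indices = torch.topk(logits, k=2, dim=-1)  # pick top-k experts
expert_weights = expert_weights.softmax(dim=-1)                    # normalize over the selected experts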
  w2_bias: torch.Tensor,
320
  gradient_scale: Optional[float] = None,
321
  alpha: float = 1.702,
322
+ limit: float = 7.0,
323
  ):
324
  # Scale weights
325
  w1 = scale_grad(w1, gradient_scale)
 
335
 
336
  # Forward pass
337
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
338
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
339
+ gate = gate.clamp(min=None, max=limit)
340
+ up = up.clamp(min=-limit, max=limit)
341
  glu = gate * torch.sigmoid(gate * alpha)
342
+ next_states = torch.bmm(((up + 1) * glu), w2)
343
+ next_states += w2_bias[..., None, :]
344
+ return next_states
 
345
 
346
  # Shared expert MLP forward pass
347
  def shared_mlp_forward(
 
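Editor's note (not part of the diff): the reworked mlp_forward treats the fused gate_up projection as interleaved gate/up columns, clamps them (the gate from above only), and applies a sigmoid-gated linear unit before the down projection. The core of that activation as a standalone sketch, using the defaults shown above (alpha=1.702, limit=7.0) and illustrative shapes:

import torch

alpha, limit = 1.702, 7.0
gate_up = torch.randn(4, 16, 2 * 32)              # (experts, tokens per expert, 2 * ffn width)
gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # de-interleave gate and up columns
gate = gate.clamp(max=limit)                      # clamp gate from above only
up = up.clamp(min=-limit, max=limit)              # clamp up on both sides
glu = gate * torch.sigmoid(gate * alpha)          # SiLU-style gating
hidden = (up + 1) * glu                           # in the diff: torch.bmm(hidden, w2) + w2_bias follows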
375
 
376
  # Up projection
377
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
378
+
379
  # Activation
380
  x = activation_fn(x)
381
+
382
  # Down projection
383
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
384
+
385
  return x
386
 
387
 
 
848
  def moe_forward(
849
  x: torch.Tensor,
850
  router_weight: torch.Tensor,
851
+ router_bias: Optional[torch.Tensor],
852
  moe_top_k: int,
853
  moe_num_experts: int,
854
  moe_jitter_eps: float = None,
 
874
  logits, expert_weights, expert_indices = route_tokens(
875
  x,
876
  router_weight,
877
+ router_bias,
878
  moe_top_k,
879
  moe_num_experts,
880
  moe_jitter_eps,
 
936
  def moe_forward_with_shared_expert(
937
  x: torch.Tensor,
938
  router_weight: torch.Tensor,
939
+ router_bias: Optional[torch.Tensor],
940
  moe_top_k: int,
941
  moe_num_experts: int,
942
  moe_jitter_eps: float = None,
 
969
  expert_out, expert_weights, router_scores = moe_forward(
970
  x=x,
971
  router_weight=router_weight,
972
+ router_bias=router_bias,
973
  moe_top_k=moe_top_k,
974
  moe_num_experts=moe_num_experts,
975
  moe_jitter_eps=moe_jitter_eps,
 
990
  hidden_size=hidden_size,
991
  mlp_impl=mlp_impl,
992
  )
993
+
994
  # If shared expert weights provided, compute shared expert output
995
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
996
  shared_expert_out = shared_mlp_forward(
 
1002
  activation_fn=shared_activation_fn,
1003
  gradient_scale=gradient_scale,
1004
  )
1005
+
1006
  # Combine expert outputs
1007
  combined_out = combine_expert_shared_outputs(
1008
  shared_expert_out=shared_expert_out,
 
1010
  shared_expert_weighted_sum=shared_expert_weighted_sum,
1011
  moe_top_k=moe_top_k,
1012
  )
1013
+
1014
  return combined_out, expert_weights, router_scores
1015
+
1016
  # Return regular MoE output if no shared expert
1017
  return expert_out, expert_weights, router_scores
1018
 
 
1028
 
1029
  if output_layer_init_method is None:
1030
  output_layer_init_method = init_method
1031
+
1032
  # Create weight tensors
1033
  up_proj_weight = torch.empty(
1034
  shared_expert_hidden_size,
 
1042
  device=device,
1043
  dtype=dtype,
1044
  )
1045
+
1046
  # Initialize weights
1047
  init_method(up_proj_weight)
1048
  output_layer_init_method(down_proj_weight)
1049
+
1050
  # No bias by default
1051
  return up_proj_weight, down_proj_weight, None, None
1052
 
1053
+
1054
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
1055
  # This exists because device_mesh is trapped in hook closures with no model attribute
1056
  # Fragile - breaks if hook structure changes or Python internals change
 
1059
  # Extract device_mesh from child's unused pre_hook closure
1060
  try:
1061
  # Find the pre-hook that contains 'device_mesh' in its closure
1062
+ hook = next(
1063
+ h
1064
+ for h in model.experts._forward_pre_hooks.values()
1065
+ if "device_mesh" in h.__code__.co_freevars
1066
+ )
1067
  # Extract the device_mesh from the closure
1068
+ return hook.__closure__[
1069
+ hook.__code__.co_freevars.index("device_mesh")
1070
+ ].cell_contents
1071
  except Exception:
1072
  return None
1073
 
1074
 
1075
  class MegaBlocksMoeMLP(torch.nn.Module):
1076
+ can_torch_compile: bool = True
1077
 
1078
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1079
  moe_top_k = getattr(self.router, "top_k", 4)
 
1082
  alpha = getattr(self.experts, "alpha", 1.0)
1083
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1084
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1085
+ moe_normalize_expert_weights = getattr(
1086
+ self.experts, "normalize_expert_weights", None
1087
+ )
1088
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1089
 
1090
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1092
  device_mesh = get_device_mesh(self)
1093
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1094
 
1095
+ has_parallel = (
1096
+ expert_parallel_group is not None
1097
+ and dist.is_initialized()
1098
+ and dist.get_world_size(expert_parallel_group) > 1
1099
+ )
1100
  forward_fn = parallel_forward_once if has_parallel else forward_once
1101
+
1102
+ sort_end_bit = max(
1103
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1104
+ )
1105
  mlp_impl = getattr(self, "mlp_impl", "grouped")
 
1106
  output, expert_weights_out, *_ = moe_forward(
1107
  x=x,
1108
  router_weight=self.router.weight,
1109
+ router_bias=self.router.bias,
1110
  moe_top_k=moe_top_k,
1111
  moe_num_experts=moe_num_experts,
1112
  moe_jitter_eps=moe_jitter_eps,
 
1130
  return output, expert_weights_out
1131
 
1132
 
1133
+ # Export main classes
1134
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
1135
+
1136
+
1137
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
1138
+
1139
  def __init__(self):
1140
  super().__init__()
1141
  # Shared expert weights will be set by the user
 
1145
  self.shared_down_proj_bias = None
1146
  self.shared_expert_weighted_sum = False
1147
  self.shared_activation_fn = None
1148
+
1149
  def set_shared_expert_weights(
1150
  self,
1151
  up_proj_weight: torch.Tensor,
 
1161
  self.shared_down_proj_bias = down_proj_bias
1162
  self.shared_expert_weighted_sum = weighted_sum
1163
  self.shared_activation_fn = activation_fn
1164
+
1165
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1166
  moe_top_k = getattr(self.router, "top_k", 4)
1167
  moe_num_experts = getattr(self.experts, "num_experts", 128)
 
1169
  alpha = getattr(self.experts, "alpha", 1.0)
1170
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1171
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1172
+ moe_normalize_expert_weights = getattr(
1173
+ self.experts, "normalize_expert_weights", None
1174
+ )
1175
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1176
 
1177
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1179
  device_mesh = get_device_mesh(self)
1180
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1181
 
1182
+ has_parallel = (
1183
+ expert_parallel_group is not None
1184
+ and dist.is_initialized()
1185
+ and dist.get_world_size(expert_parallel_group) > 1
1186
+ )
1187
  forward_fn = parallel_forward_once if has_parallel else forward_once
1188
+
1189
+ sort_end_bit = max(
1190
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1191
+ )
1192
  mlp_impl = getattr(self, "mlp_impl", "grouped")
1193
+
1194
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
1195
  x=x,
1196
  router_weight=self.router.weight,
1197
+ router_bias=self.router.bias,
1198
  moe_top_k=moe_top_k,
1199
  moe_num_experts=moe_num_experts,
1200
  moe_jitter_eps=moe_jitter_eps,
 
1222
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
1223
  shared_activation_fn=self.shared_activation_fn,
1224
  )
1225
+ return output, expert_weights_out
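Editor's note (not part of the diff): the get_device_mesh hook-closure hack reformatted above is easier to follow outside the diff. A self-contained sketch of the same closure introspection, where a plain string stands in for the real DeviceMesh and a Linear module stands in for the experts submodule:

import torch

def make_pre_hook(device_mesh):
    # Mimics an integration layer that captures device_mesh inside a hook closure.
    def pre_hook(module, args):
        return None
    return pre_hook

experts = torch.nn.Linear(4, 4)                       # stand-in for the experts module
experts.register_forward_pre_hook(make_pre_hook("toy-device-mesh"))

# Recover the captured value the same way get_device_mesh does:
hook = next(
    h for h in experts._forward_pre_hooks.values()
    if "device_mesh" in h.__code__.co_freevars
)
mesh = hook.__closure__[hook.__code__.co_freevars.index("device_mesh")].cell_contents
print(mesh)                                            # prints "toy-device-mesh"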
build/torch28-cxx11-cu129-x86_64-linux/megablocks/{_megablocks_3bdb4b8_dirty.abi3.so β†’ _megablocks_8176cbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c2535eb4cc58df36b630efb47058f3d277e0f9ddeb0591156f620d73be6d848
3
  size 13585072
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b84b6d64ceb3ef6f5cca709cdca1ec8c79c500b6bbb636c003a7d72fb58e6acf
3
  size 13585072
build/torch28-cxx11-cu129-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_3bdb4b8_dirty
3
- ops = torch.ops._megablocks_3bdb4b8_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_3bdb4b8_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_8176cbe_dirty
3
+ ops = torch.ops._megablocks_8176cbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_8176cbe_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/megablocks/layers.py CHANGED
@@ -1,11 +1,200 @@
1
  import torch
2
  import torch.distributed as dist
3
 
4
- from typing import Optional, Any
5
 
6
  from . import _layers
7
  from . import ops
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Set the expert model parallel attributes on a tensor
11
  def set_expert_model_parallel_attributes(
@@ -80,6 +269,7 @@ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
80
  def route_tokens(
81
  x: torch.Tensor,
82
  router_weight: torch.Tensor,
 
83
  moe_top_k: int,
84
  moe_num_experts: int,
85
  moe_jitter_eps: float = None,
@@ -91,7 +281,7 @@ def route_tokens(
91
  x = apply_jitter(x, moe_jitter_eps)
92
 
93
  x_flat = x.view(-1, x.shape[-1])
94
- logits = torch.nn.functional.linear(x_flat, router_weight)
95
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
96
  expert_weights = expert_weights.softmax(dim=-1)
97
  if moe_normalize_expert_weights is not None:
@@ -129,6 +319,7 @@ def mlp_forward(
129
  w2_bias: torch.Tensor,
130
  gradient_scale: Optional[float] = None,
131
  alpha: float = 1.702,
 
132
  ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
@@ -144,13 +335,13 @@ def mlp_forward(
144
 
145
  # Forward pass
146
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
147
- gate, up = gate_up.chunk(2, dim=-1)
148
-
 
149
  glu = gate * torch.sigmoid(gate * alpha)
150
- x = (up + 1) * glu
151
-
152
- return torch.bmm(x, w2) + w2_bias[..., None, :]
153
-
154
 
155
  # Shared expert MLP forward pass
156
  def shared_mlp_forward(
@@ -184,13 +375,13 @@ def shared_mlp_forward(
184
 
185
  # Up projection
186
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
-
188
  # Activation
189
  x = activation_fn(x)
190
-
191
  # Down projection
192
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
-
194
  return x
195
 
196
 
@@ -657,6 +848,7 @@ def parallel_forward_once(
657
  def moe_forward(
658
  x: torch.Tensor,
659
  router_weight: torch.Tensor,
 
660
  moe_top_k: int,
661
  moe_num_experts: int,
662
  moe_jitter_eps: float = None,
@@ -682,6 +874,7 @@ def moe_forward(
682
  logits, expert_weights, expert_indices = route_tokens(
683
  x,
684
  router_weight,
 
685
  moe_top_k,
686
  moe_num_experts,
687
  moe_jitter_eps,
@@ -743,6 +936,7 @@ def moe_forward(
743
  def moe_forward_with_shared_expert(
744
  x: torch.Tensor,
745
  router_weight: torch.Tensor,
 
746
  moe_top_k: int,
747
  moe_num_experts: int,
748
  moe_jitter_eps: float = None,
@@ -775,6 +969,7 @@ def moe_forward_with_shared_expert(
775
  expert_out, expert_weights, router_scores = moe_forward(
776
  x=x,
777
  router_weight=router_weight,
 
778
  moe_top_k=moe_top_k,
779
  moe_num_experts=moe_num_experts,
780
  moe_jitter_eps=moe_jitter_eps,
@@ -795,7 +990,7 @@ def moe_forward_with_shared_expert(
795
  hidden_size=hidden_size,
796
  mlp_impl=mlp_impl,
797
  )
798
-
799
  # If shared expert weights provided, compute shared expert output
800
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
  shared_expert_out = shared_mlp_forward(
@@ -807,7 +1002,7 @@ def moe_forward_with_shared_expert(
807
  activation_fn=shared_activation_fn,
808
  gradient_scale=gradient_scale,
809
  )
810
-
811
  # Combine expert outputs
812
  combined_out = combine_expert_shared_outputs(
813
  shared_expert_out=shared_expert_out,
@@ -815,9 +1010,9 @@ def moe_forward_with_shared_expert(
815
  shared_expert_weighted_sum=shared_expert_weighted_sum,
816
  moe_top_k=moe_top_k,
817
  )
818
-
819
  return combined_out, expert_weights, router_scores
820
-
821
  # Return regular MoE output if no shared expert
822
  return expert_out, expert_weights, router_scores
823
 
@@ -833,7 +1028,7 @@ def create_shared_expert_weights(
833
 
834
  if output_layer_init_method is None:
835
  output_layer_init_method = init_method
836
-
837
  # Create weight tensors
838
  up_proj_weight = torch.empty(
839
  shared_expert_hidden_size,
@@ -847,14 +1042,15 @@ def create_shared_expert_weights(
847
  device=device,
848
  dtype=dtype,
849
  )
850
-
851
  # Initialize weights
852
  init_method(up_proj_weight)
853
  output_layer_init_method(down_proj_weight)
854
-
855
  # No bias by default
856
  return up_proj_weight, down_proj_weight, None, None
857
 
 
858
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
  # This exists because device_mesh is trapped in hook closures with no model attribute
860
  # Fragile - breaks if hook structure changes or Python internals change
@@ -863,14 +1059,21 @@ def get_device_mesh(model):
863
  # Extract device_mesh from child's unused pre_hook closure
864
  try:
865
  # Find the pre-hook that contains 'device_mesh' in its closure
866
- hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
 
 
 
 
867
  # Extract the device_mesh from the closure
868
- return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
 
 
869
  except Exception:
870
  return None
871
 
872
 
873
  class MegaBlocksMoeMLP(torch.nn.Module):
 
874
 
875
  def forward(self, x: torch.Tensor) -> torch.Tensor:
876
  moe_top_k = getattr(self.router, "top_k", 4)
@@ -879,7 +1082,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
879
  alpha = getattr(self.experts, "alpha", 1.0)
880
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
 
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -887,15 +1092,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
887
  device_mesh = get_device_mesh(self)
888
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
 
890
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
-
893
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
894
  mlp_impl = getattr(self, "mlp_impl", "grouped")
895
-
896
  output, expert_weights_out, *_ = moe_forward(
897
  x=x,
898
  router_weight=self.router.weight,
 
899
  moe_top_k=moe_top_k,
900
  moe_num_experts=moe_num_experts,
901
  moe_jitter_eps=moe_jitter_eps,
@@ -919,8 +1130,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
919
  return output, expert_weights_out
920
 
921
 
 
 
 
 
922
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
-
924
  def __init__(self):
925
  super().__init__()
926
  # Shared expert weights will be set by the user
@@ -930,7 +1145,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
930
  self.shared_down_proj_bias = None
931
  self.shared_expert_weighted_sum = False
932
  self.shared_activation_fn = None
933
-
934
  def set_shared_expert_weights(
935
  self,
936
  up_proj_weight: torch.Tensor,
@@ -946,7 +1161,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
946
  self.shared_down_proj_bias = down_proj_bias
947
  self.shared_expert_weighted_sum = weighted_sum
948
  self.shared_activation_fn = activation_fn
949
-
950
  def forward(self, x: torch.Tensor) -> torch.Tensor:
951
  moe_top_k = getattr(self.router, "top_k", 4)
952
  moe_num_experts = getattr(self.experts, "num_experts", 128)
@@ -954,7 +1169,9 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
954
  alpha = getattr(self.experts, "alpha", 1.0)
955
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
958
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
 
960
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -962,15 +1179,22 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
962
  device_mesh = get_device_mesh(self)
963
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
 
965
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
 
 
 
 
966
  forward_fn = parallel_forward_once if has_parallel else forward_once
967
-
968
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
969
  mlp_impl = getattr(self, "mlp_impl", "grouped")
970
-
971
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
  x=x,
973
  router_weight=self.router.weight,
 
974
  moe_top_k=moe_top_k,
975
  moe_num_experts=moe_num_experts,
976
  moe_jitter_eps=moe_jitter_eps,
@@ -998,4 +1222,4 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
998
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
  shared_activation_fn=self.shared_activation_fn,
1000
  )
1001
- return output, expert_weights_out
 
1
  import torch
2
  import torch.distributed as dist
3
 
4
+ from typing import Optional, Any, TYPE_CHECKING
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
+ # Conditional import for meta kernel registration
10
+ if TYPE_CHECKING:
11
+
12
+ def register_fake(fn):
13
+ return lambda name: fn
14
+
15
+ else:
16
+ try:
17
+ from torch.library import register_fake
18
+ except ImportError:
19
+ try:
20
+ from torch.library import impl_abstract as register_fake
21
+ except ImportError:
22
+ # Fallback for older PyTorch versions
23
+ def register_fake(op_name):
24
+ def decorator(fn):
25
+ return fn
26
+
27
+ return decorator
28
+
29
+
30
+ # Meta kernel implementations for torch.compile compatibility
31
+ def _install_meta_kernels():
32
+ """Install meta kernels for existing MegaBlocks operations"""
33
+
34
+ # Create wrapper functions that check for compilation and return meta tensors
35
+
36
+ # Patch ops.sort
37
+ if hasattr(ops, "sort"):
38
+ original_sort = ops.sort
39
+
40
+ def sort_with_meta(x, end_bit=None):
41
+ if torch.compiler.is_compiling():
42
+ print("Using meta kernel for sort")
43
+ # Meta implementation - return tensors with correct shape/dtype/device
44
+ return torch.empty_like(x), torch.empty_like(x)
45
+ # print("Using original sort kernel")
46
+ return original_sort(x, end_bit)
47
+
48
+ ops.sort = sort_with_meta
49
+
50
+ # Patch ops.histogram
51
+ if hasattr(ops, "histogram"):
52
+ original_histogram = ops.histogram
53
+
54
+ def histogram_with_meta(x, max_val):
55
+ if torch.compiler.is_compiling():
56
+ # Meta implementation
57
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
58
+ return original_histogram(x, max_val)
59
+
60
+ ops.histogram = histogram_with_meta
61
+
62
+ # Patch ops.inclusive_cumsum
63
+ if hasattr(ops, "inclusive_cumsum"):
64
+ original_inclusive_cumsum = ops.inclusive_cumsum
65
+
66
+ def inclusive_cumsum_with_meta(x, dim):
67
+ if torch.compiler.is_compiling():
68
+ # Meta implementation
69
+ return torch.empty_like(x)
70
+ return original_inclusive_cumsum(x, dim)
71
+
72
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
73
+
74
+ # Patch ops.binned_gather
75
+ if hasattr(ops, "binned_gather"):
76
+ original_binned_gather = ops.binned_gather
77
+
78
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
79
+ if torch.compiler.is_compiling():
80
+ # Meta implementation - output shape based on bin_size
81
+ if x.dim() >= 2:
82
+ hidden_size = x.size(-1)
83
+ return torch.empty(
84
+ (bin_size, x.size(1), hidden_size),
85
+ dtype=x.dtype,
86
+ device=x.device,
87
+ )
88
+ else:
89
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
90
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
91
+
92
+ ops.binned_gather = binned_gather_with_meta
93
+
94
+ # Patch ops.binned_scatter
95
+ if hasattr(ops, "binned_scatter"):
96
+ original_binned_scatter = ops.binned_scatter
97
+
98
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
99
+ if torch.compiler.is_compiling():
100
+ # Meta implementation - typically reduces to 2D
101
+ if x.dim() >= 3:
102
+ return torch.empty(
103
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
104
+ )
105
+ else:
106
+ return torch.empty_like(x)
107
+ return original_binned_scatter(x, indices, weights, bins, top_k)
108
+
109
+ ops.binned_scatter = binned_scatter_with_meta
110
+
111
+ # Patch ops.gather
112
+ if hasattr(ops, "gather"):
113
+ original_gather = ops.gather
114
+
115
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
116
+ if torch.compiler.is_compiling():
117
+ # Meta implementation
118
+ if x.dim() >= 2:
119
+ hidden_size = x.size(-1)
120
+ return torch.empty(
121
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
122
+ )
123
+ else:
124
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
125
+ return original_gather(x, indices, bin_ids, bins, top_k)
126
+
127
+ ops.gather = gather_with_meta
128
+
129
+ # Patch ops.scatter
130
+ if hasattr(ops, "scatter"):
131
+ original_scatter = ops.scatter
132
+
133
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
134
+ if torch.compiler.is_compiling():
135
+ # Meta implementation - restore sequence shape
136
+ seq_len = (
137
+ indices.size(0) // top_k
138
+ if indices.numel() > 0 and top_k > 0
139
+ else x.size(0)
140
+ )
141
+ if x.dim() >= 2:
142
+ return torch.empty(
143
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
144
+ )
145
+ else:
146
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
147
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
148
+
149
+ ops.scatter = scatter_with_meta
150
+
151
+ # Patch ops.replicate
152
+ if hasattr(ops, "replicate"):
153
+ original_replicate = ops.replicate
154
+
155
+ def replicate_with_meta(x, bins, num_outputs):
156
+ if torch.compiler.is_compiling():
157
+ # Meta implementation
158
+ return torch.empty(
159
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
160
+ )
161
+ return original_replicate(x, bins, num_outputs)
162
+
163
+ ops.replicate = replicate_with_meta
164
+
165
+ # Patch ops.repeat (if it's a regular function)
166
+ if hasattr(ops, "repeat"):
167
+ original_repeat = ops.repeat
168
+
169
+ def repeat_with_meta(x, repeats):
170
+ if torch.compiler.is_compiling():
171
+ # Meta implementation
172
+ if isinstance(repeats, (tuple, list)):
173
+ new_shape = list(x.shape)
174
+ for i, rep in enumerate(repeats):
175
+ if i < len(new_shape):
176
+ new_shape[i] *= rep
177
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
178
+ else:
179
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
180
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
181
+ return original_repeat(x, repeats)
182
+
183
+ ops.repeat = repeat_with_meta
184
+
185
+
186
+ # Install meta kernels on import
187
+ try:
188
+ _install_meta_kernels()
189
+ except Exception as e:
190
+ # If meta kernel installation fails, continue without them
191
+ # torch.compile may not work but the library will still function
192
+ import warnings
193
+
194
+ warnings.warn(
195
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
196
+ )
197
+
198
 
199
  # Set the expert model parallel attributes on a tensor
200
  def set_expert_model_parallel_attributes(
 
269
  def route_tokens(
270
  x: torch.Tensor,
271
  router_weight: torch.Tensor,
272
+ router_bias: torch.Tensor,
273
  moe_top_k: int,
274
  moe_num_experts: int,
275
  moe_jitter_eps: float = None,
 
281
  x = apply_jitter(x, moe_jitter_eps)
282
 
283
  x_flat = x.view(-1, x.shape[-1])
284
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
285
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
286
  expert_weights = expert_weights.softmax(dim=-1)
287
  if moe_normalize_expert_weights is not None:
 
319
  w2_bias: torch.Tensor,
320
  gradient_scale: Optional[float] = None,
321
  alpha: float = 1.702,
322
+ limit: float = 7.0,
323
  ):
324
  # Scale weights
325
  w1 = scale_grad(w1, gradient_scale)
 
335
 
336
  # Forward pass
337
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
338
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
339
+ gate = gate.clamp(min=None, max=limit)
340
+ up = up.clamp(min=-limit, max=limit)
341
  glu = gate * torch.sigmoid(gate * alpha)
342
+ next_states = torch.bmm(((up + 1) * glu), w2)
343
+ next_states += w2_bias[..., None, :]
344
+ return next_states
 
345
 
346
  # Shared expert MLP forward pass
347
  def shared_mlp_forward(
 
375
 
376
  # Up projection
377
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
378
+
379
  # Activation
380
  x = activation_fn(x)
381
+
382
  # Down projection
383
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
384
+
385
  return x
386
 
387
 
 
848
  def moe_forward(
849
  x: torch.Tensor,
850
  router_weight: torch.Tensor,
851
+ router_bias: Optional[torch.Tensor],
852
  moe_top_k: int,
853
  moe_num_experts: int,
854
  moe_jitter_eps: float = None,
 
874
  logits, expert_weights, expert_indices = route_tokens(
875
  x,
876
  router_weight,
877
+ router_bias,
878
  moe_top_k,
879
  moe_num_experts,
880
  moe_jitter_eps,
 
936
  def moe_forward_with_shared_expert(
937
  x: torch.Tensor,
938
  router_weight: torch.Tensor,
939
+ router_bias: Optional[torch.Tensor],
940
  moe_top_k: int,
941
  moe_num_experts: int,
942
  moe_jitter_eps: float = None,
 
969
  expert_out, expert_weights, router_scores = moe_forward(
970
  x=x,
971
  router_weight=router_weight,
972
+ router_bias=router_bias,
973
  moe_top_k=moe_top_k,
974
  moe_num_experts=moe_num_experts,
975
  moe_jitter_eps=moe_jitter_eps,
 
990
  hidden_size=hidden_size,
991
  mlp_impl=mlp_impl,
992
  )
993
+
994
  # If shared expert weights provided, compute shared expert output
995
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
996
  shared_expert_out = shared_mlp_forward(
 
1002
  activation_fn=shared_activation_fn,
1003
  gradient_scale=gradient_scale,
1004
  )
1005
+
1006
  # Combine expert outputs
1007
  combined_out = combine_expert_shared_outputs(
1008
  shared_expert_out=shared_expert_out,
 
1010
  shared_expert_weighted_sum=shared_expert_weighted_sum,
1011
  moe_top_k=moe_top_k,
1012
  )
1013
+
1014
  return combined_out, expert_weights, router_scores
1015
+
1016
  # Return regular MoE output if no shared expert
1017
  return expert_out, expert_weights, router_scores
1018
 
 
1028
 
1029
  if output_layer_init_method is None:
1030
  output_layer_init_method = init_method
1031
+
1032
  # Create weight tensors
1033
  up_proj_weight = torch.empty(
1034
  shared_expert_hidden_size,
 
1042
  device=device,
1043
  dtype=dtype,
1044
  )
1045
+
1046
  # Initialize weights
1047
  init_method(up_proj_weight)
1048
  output_layer_init_method(down_proj_weight)
1049
+
1050
  # No bias by default
1051
  return up_proj_weight, down_proj_weight, None, None
1052
 
1053
+
1054
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
1055
  # This exists because device_mesh is trapped in hook closures with no model attribute
1056
  # Fragile - breaks if hook structure changes or Python internals change
 
1059
  # Extract device_mesh from child's unused pre_hook closure
1060
  try:
1061
  # Find the pre-hook that contains 'device_mesh' in its closure
1062
+ hook = next(
1063
+ h
1064
+ for h in model.experts._forward_pre_hooks.values()
1065
+ if "device_mesh" in h.__code__.co_freevars
1066
+ )
1067
  # Extract the device_mesh from the closure
1068
+ return hook.__closure__[
1069
+ hook.__code__.co_freevars.index("device_mesh")
1070
+ ].cell_contents
1071
  except Exception:
1072
  return None
1073
 
1074
 
1075
  class MegaBlocksMoeMLP(torch.nn.Module):
1076
+ can_torch_compile: bool = True
1077
 
1078
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1079
  moe_top_k = getattr(self.router, "top_k", 4)
 
1082
  alpha = getattr(self.experts, "alpha", 1.0)
1083
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1084
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1085
+ moe_normalize_expert_weights = getattr(
1086
+ self.experts, "normalize_expert_weights", None
1087
+ )
1088
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1089
 
1090
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1092
  device_mesh = get_device_mesh(self)
1093
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1094
 
1095
+ has_parallel = (
1096
+ expert_parallel_group is not None
1097
+ and dist.is_initialized()
1098
+ and dist.get_world_size(expert_parallel_group) > 1
1099
+ )
1100
  forward_fn = parallel_forward_once if has_parallel else forward_once
1101
+
1102
+ sort_end_bit = max(
1103
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1104
+ )
1105
  mlp_impl = getattr(self, "mlp_impl", "grouped")
 
1106
  output, expert_weights_out, *_ = moe_forward(
1107
  x=x,
1108
  router_weight=self.router.weight,
1109
+ router_bias=self.router.bias,
1110
  moe_top_k=moe_top_k,
1111
  moe_num_experts=moe_num_experts,
1112
  moe_jitter_eps=moe_jitter_eps,
 
1130
  return output, expert_weights_out
1131
 
1132
 
1133
+ # Export main classes
1134
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
1135
+
1136
+
1137
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
1138
+
1139
  def __init__(self):
1140
  super().__init__()
1141
  # Shared expert weights will be set by the user
 
1145
  self.shared_down_proj_bias = None
1146
  self.shared_expert_weighted_sum = False
1147
  self.shared_activation_fn = None
1148
+
1149
  def set_shared_expert_weights(
1150
  self,
1151
  up_proj_weight: torch.Tensor,
 
1161
  self.shared_down_proj_bias = down_proj_bias
1162
  self.shared_expert_weighted_sum = weighted_sum
1163
  self.shared_activation_fn = activation_fn
1164
+
1165
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1166
  moe_top_k = getattr(self.router, "top_k", 4)
1167
  moe_num_experts = getattr(self.experts, "num_experts", 128)
 
1169
  alpha = getattr(self.experts, "alpha", 1.0)
1170
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1171
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1172
+ moe_normalize_expert_weights = getattr(
1173
+ self.experts, "normalize_expert_weights", None
1174
+ )
1175
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1176
 
1177
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1179
  device_mesh = get_device_mesh(self)
1180
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1181
 
1182
+ has_parallel = (
1183
+ expert_parallel_group is not None
1184
+ and dist.is_initialized()
1185
+ and dist.get_world_size(expert_parallel_group) > 1
1186
+ )
1187
  forward_fn = parallel_forward_once if has_parallel else forward_once
1188
+
1189
+ sort_end_bit = max(
1190
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1191
+ )
1192
  mlp_impl = getattr(self, "mlp_impl", "grouped")
1193
+
1194
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
1195
  x=x,
1196
  router_weight=self.router.weight,
1197
+ router_bias=self.router.bias,
1198
  moe_top_k=moe_top_k,
1199
  moe_num_experts=moe_num_experts,
1200
  moe_jitter_eps=moe_jitter_eps,
 
1222
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
1223
  shared_activation_fn=self.shared_activation_fn,
1224
  )
1225
+ return output, expert_weights_out
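Editor's note (not part of the diff): MegaBlocksMoeMLP now declares can_torch_compile = True, which an integrator can use to decide whether to wrap a layer with torch.compile. A minimal sketch of such a gate; the helper name maybe_compile is illustrative, and the router/experts submodules must still be attached by the host model before forward() is called:

import torch

def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
    # Only compile layers that declare torch.compile compatibility,
    # as MegaBlocksMoeMLP does via the can_torch_compile class attribute.
    if getattr(type(module), "can_torch_compile", False):
        return torch.compile(module)
    return module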