drbh committed
Commit cd5b9c4 · 1 Parent(s): e0fb143

fix: support torch compile via fake tensors
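For context: "fake tensors" are the metadata-only tensors torch.compile uses to propagate shapes and dtypes through a model without allocating storage or launching kernels, and the meta kernels added below exist so the MegaBlocks ops can take part in that propagation. A minimal illustration of the idea (FakeTensorMode is a private but commonly used PyTorch API; the shapes are arbitrary):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        x = torch.randn(8, 1024)              # no real memory is allocated
        w = torch.randn(4096, 1024)
        y = torch.nn.functional.linear(x, w)  # shape/dtype propagation only
        print(y.shape, y.dtype)               # torch.Size([8, 4096]) torch.float32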

Files changed (1):
  1. torch-ext/megablocks/layers.py  +257 -34
torch-ext/megablocks/layers.py CHANGED
@@ -1,11 +1,200 @@
1
  import torch
2
  import torch.distributed as dist
3
 
4
- from typing import Optional, Any
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
 
10
  # Set the expert model parallel attributes on a tensor
11
  def set_expert_model_parallel_attributes(
@@ -80,6 +269,7 @@ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
80
  def route_tokens(
81
  x: torch.Tensor,
82
  router_weight: torch.Tensor,
 
83
  moe_top_k: int,
84
  moe_num_experts: int,
85
  moe_jitter_eps: float = None,
@@ -91,7 +281,7 @@ def route_tokens(
91
  x = apply_jitter(x, moe_jitter_eps)
92
 
93
  x_flat = x.view(-1, x.shape[-1])
94
- logits = torch.nn.functional.linear(x_flat, router_weight)
95
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
96
  expert_weights = expert_weights.softmax(dim=-1)
97
  if moe_normalize_expert_weights is not None:
@@ -129,6 +319,7 @@ def mlp_forward(
129
  w2_bias: torch.Tensor,
130
  gradient_scale: Optional[float] = None,
131
  alpha: float = 1.702,
 
132
  ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
@@ -144,13 +335,13 @@ def mlp_forward(
144
 
145
  # Forward pass
146
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
147
- gate, up = gate_up.chunk(2, dim=-1)
148
-
 
149
  glu = gate * torch.sigmoid(gate * alpha)
150
- x = (up + 1) * glu
151
-
152
- return torch.bmm(x, w2) + w2_bias[..., None, :]
153
-
154
 
155
  # Shared expert MLP forward pass
156
  def shared_mlp_forward(
@@ -184,13 +375,13 @@ def shared_mlp_forward(
184
 
185
  # Up projection
186
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
-
188
  # Activation
189
  x = activation_fn(x)
190
-
191
  # Down projection
192
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
-
194
  return x
195
 
196
 
@@ -657,6 +848,7 @@ def parallel_forward_once(
657
  def moe_forward(
658
  x: torch.Tensor,
659
  router_weight: torch.Tensor,
 
660
  moe_top_k: int,
661
  moe_num_experts: int,
662
  moe_jitter_eps: float = None,
@@ -682,6 +874,7 @@ def moe_forward(
682
  logits, expert_weights, expert_indices = route_tokens(
683
  x,
684
  router_weight,
 
685
  moe_top_k,
686
  moe_num_experts,
687
  moe_jitter_eps,
@@ -743,6 +936,7 @@ def moe_forward(
743
  def moe_forward_with_shared_expert(
744
  x: torch.Tensor,
745
  router_weight: torch.Tensor,
 
746
  moe_top_k: int,
747
  moe_num_experts: int,
748
  moe_jitter_eps: float = None,
@@ -775,6 +969,7 @@ def moe_forward_with_shared_expert(
775
  expert_out, expert_weights, router_scores = moe_forward(
776
  x=x,
777
  router_weight=router_weight,
 
778
  moe_top_k=moe_top_k,
779
  moe_num_experts=moe_num_experts,
780
  moe_jitter_eps=moe_jitter_eps,
@@ -795,7 +990,7 @@ def moe_forward_with_shared_expert(
795
  hidden_size=hidden_size,
796
  mlp_impl=mlp_impl,
797
  )
798
-
799
  # If shared expert weights provided, compute shared expert output
800
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
  shared_expert_out = shared_mlp_forward(
@@ -807,7 +1002,7 @@ def moe_forward_with_shared_expert(
807
  activation_fn=shared_activation_fn,
808
  gradient_scale=gradient_scale,
809
  )
810
-
811
  # Combine expert outputs
812
  combined_out = combine_expert_shared_outputs(
813
  shared_expert_out=shared_expert_out,
@@ -815,9 +1010,9 @@ def moe_forward_with_shared_expert(
815
  shared_expert_weighted_sum=shared_expert_weighted_sum,
816
  moe_top_k=moe_top_k,
817
  )
818
-
819
  return combined_out, expert_weights, router_scores
820
-
821
  # Return regular MoE output if no shared expert
822
  return expert_out, expert_weights, router_scores
823
 
@@ -833,7 +1028,7 @@ def create_shared_expert_weights(
833
 
834
  if output_layer_init_method is None:
835
  output_layer_init_method = init_method
836
-
837
  # Create weight tensors
838
  up_proj_weight = torch.empty(
839
  shared_expert_hidden_size,
@@ -847,14 +1042,15 @@ def create_shared_expert_weights(
847
  device=device,
848
  dtype=dtype,
849
  )
850
-
851
  # Initialize weights
852
  init_method(up_proj_weight)
853
  output_layer_init_method(down_proj_weight)
854
-
855
  # No bias by default
856
  return up_proj_weight, down_proj_weight, None, None
857
 
 
858
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
  # This exists because device_mesh is trapped in hook closures with no model attribute
860
  # Fragile - breaks if hook structure changes or Python internals change
@@ -863,9 +1059,15 @@ def get_device_mesh(model):
863
  # Extract device_mesh from child's unused pre_hook closure
864
  try:
865
  # Find the pre-hook that contains 'device_mesh' in its closure
866
- hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
  # Extract the device_mesh from the closure
868
- return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
 
 
869
  except Exception:
870
  return None
871
 
@@ -879,7 +1081,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
879
  alpha = getattr(self.experts, "alpha", 1.0)
880
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
881
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
883
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
 
885
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -887,15 +1091,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
887
  device_mesh = get_device_mesh(self)
888
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
 
890
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
  forward_fn = parallel_forward_once if has_parallel else forward_once
892
-
893
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
894
  mlp_impl = getattr(self, "mlp_impl", "grouped")
895
-
896
  output, expert_weights_out, *_ = moe_forward(
897
  x=x,
898
  router_weight=self.router.weight,
 
899
  moe_top_k=moe_top_k,
900
  moe_num_experts=moe_num_experts,
901
  moe_jitter_eps=moe_jitter_eps,
@@ -919,8 +1129,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
919
  return output, expert_weights_out
920
 
921
 
922
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
-
924
  def __init__(self):
925
  super().__init__()
926
  # Shared expert weights will be set by the user
@@ -930,7 +1144,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
930
  self.shared_down_proj_bias = None
931
  self.shared_expert_weighted_sum = False
932
  self.shared_activation_fn = None
933
-
934
  def set_shared_expert_weights(
935
  self,
936
  up_proj_weight: torch.Tensor,
@@ -946,7 +1160,7 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
946
  self.shared_down_proj_bias = down_proj_bias
947
  self.shared_expert_weighted_sum = weighted_sum
948
  self.shared_activation_fn = activation_fn
949
-
950
  def forward(self, x: torch.Tensor) -> torch.Tensor:
951
  moe_top_k = getattr(self.router, "top_k", 4)
952
  moe_num_experts = getattr(self.experts, "num_experts", 128)
@@ -954,7 +1168,9 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
954
  alpha = getattr(self.experts, "alpha", 1.0)
955
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
 
 
958
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
 
960
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
@@ -962,15 +1178,22 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
962
  device_mesh = get_device_mesh(self)
963
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
 
965
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
  forward_fn = parallel_forward_once if has_parallel else forward_once
967
-
968
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
969
  mlp_impl = getattr(self, "mlp_impl", "grouped")
970
-
971
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
  x=x,
973
  router_weight=self.router.weight,
 
974
  moe_top_k=moe_top_k,
975
  moe_num_experts=moe_num_experts,
976
  moe_jitter_eps=moe_jitter_eps,
@@ -998,4 +1221,4 @@ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
998
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
  shared_activation_fn=self.shared_activation_fn,
1000
  )
1001
- return output, expert_weights_out

torch-ext/megablocks/layers.py UPDATED (new side of the diff, additions marked with +)
1
  import torch
2
  import torch.distributed as dist
3
 
4
+ from typing import Optional, Any, TYPE_CHECKING
5
 
6
  from . import _layers
7
  from . import ops
8
 
9
+ # Conditional import for meta kernel registration
10
+ if TYPE_CHECKING:
11
+
12
+ def register_fake(fn):
13
+ return lambda name: fn
14
+
15
+ else:
16
+ try:
17
+ from torch.library import register_fake
18
+ except ImportError:
19
+ try:
20
+ from torch.library import impl_abstract as register_fake
21
+ except ImportError:
22
+ # Fallback for older PyTorch versions
23
+ def register_fake(op_name):
24
+ def decorator(fn):
25
+ return fn
26
+
27
+ return decorator
28
+
29
+
30
+ # Meta kernel implementations for torch.compile compatibility
31
+ def _install_meta_kernels():
32
+ """Install meta kernels for existing MegaBlocks operations"""
33
+
34
+ # Create wrapper functions that check for compilation and return meta tensors
35
+
36
+ # Patch ops.sort
37
+ if hasattr(ops, "sort"):
38
+ original_sort = ops.sort
39
+
40
+ def sort_with_meta(x, end_bit=None):
41
+ if torch.compiler.is_compiling():
42
+ print("Using meta kernel for sort")
43
+ # Meta implementation - return tensors with correct shape/dtype/device
44
+ return torch.empty_like(x), torch.empty_like(x)
45
+ # print("Using original sort kernel")
46
+ return original_sort(x, end_bit)
47
+
48
+ ops.sort = sort_with_meta
49
+
50
+ # Patch ops.histogram
51
+ if hasattr(ops, "histogram"):
52
+ original_histogram = ops.histogram
53
+
54
+ def histogram_with_meta(x, max_val):
55
+ if torch.compiler.is_compiling():
56
+ # Meta implementation
57
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
58
+ return original_histogram(x, max_val)
59
+
60
+ ops.histogram = histogram_with_meta
61
+
62
+ # Patch ops.inclusive_cumsum
63
+ if hasattr(ops, "inclusive_cumsum"):
64
+ original_inclusive_cumsum = ops.inclusive_cumsum
65
+
66
+ def inclusive_cumsum_with_meta(x, dim):
67
+ if torch.compiler.is_compiling():
68
+ # Meta implementation
69
+ return torch.empty_like(x)
70
+ return original_inclusive_cumsum(x, dim)
71
+
72
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
73
+
74
+ # Patch ops.binned_gather
75
+ if hasattr(ops, "binned_gather"):
76
+ original_binned_gather = ops.binned_gather
77
+
78
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
79
+ if torch.compiler.is_compiling():
80
+ # Meta implementation - output shape based on bin_size
81
+ if x.dim() >= 2:
82
+ hidden_size = x.size(-1)
83
+ return torch.empty(
84
+ (bin_size, x.size(1), hidden_size),
85
+ dtype=x.dtype,
86
+ device=x.device,
87
+ )
88
+ else:
89
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
90
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
91
+
92
+ ops.binned_gather = binned_gather_with_meta
93
+
94
+ # Patch ops.binned_scatter
95
+ if hasattr(ops, "binned_scatter"):
96
+ original_binned_scatter = ops.binned_scatter
97
+
98
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
99
+ if torch.compiler.is_compiling():
100
+ # Meta implementation - typically reduces to 2D
101
+ if x.dim() >= 3:
102
+ return torch.empty(
103
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
104
+ )
105
+ else:
106
+ return torch.empty_like(x)
107
+ return original_binned_scatter(x, indices, weights, bins, top_k)
108
+
109
+ ops.binned_scatter = binned_scatter_with_meta
110
+
111
+ # Patch ops.gather
112
+ if hasattr(ops, "gather"):
113
+ original_gather = ops.gather
114
+
115
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
116
+ if torch.compiler.is_compiling():
117
+ # Meta implementation
118
+ if x.dim() >= 2:
119
+ hidden_size = x.size(-1)
120
+ return torch.empty(
121
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
122
+ )
123
+ else:
124
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
125
+ return original_gather(x, indices, bin_ids, bins, top_k)
126
+
127
+ ops.gather = gather_with_meta
128
+
129
+ # Patch ops.scatter
130
+ if hasattr(ops, "scatter"):
131
+ original_scatter = ops.scatter
132
+
133
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
134
+ if torch.compiler.is_compiling():
135
+ # Meta implementation - restore sequence shape
136
+ seq_len = (
137
+ indices.size(0) // top_k
138
+ if indices.numel() > 0 and top_k > 0
139
+ else x.size(0)
140
+ )
141
+ if x.dim() >= 2:
142
+ return torch.empty(
143
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
144
+ )
145
+ else:
146
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
147
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
148
+
149
+ ops.scatter = scatter_with_meta
150
+
151
+ # Patch ops.replicate
152
+ if hasattr(ops, "replicate"):
153
+ original_replicate = ops.replicate
154
+
155
+ def replicate_with_meta(x, bins, num_outputs):
156
+ if torch.compiler.is_compiling():
157
+ # Meta implementation
158
+ return torch.empty(
159
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
160
+ )
161
+ return original_replicate(x, bins, num_outputs)
162
+
163
+ ops.replicate = replicate_with_meta
164
+
165
+ # Patch ops.repeat (if it's a regular function)
166
+ if hasattr(ops, "repeat"):
167
+ original_repeat = ops.repeat
168
+
169
+ def repeat_with_meta(x, repeats):
170
+ if torch.compiler.is_compiling():
171
+ # Meta implementation
172
+ if isinstance(repeats, (tuple, list)):
173
+ new_shape = list(x.shape)
174
+ for i, rep in enumerate(repeats):
175
+ if i < len(new_shape):
176
+ new_shape[i] *= rep
177
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
178
+ else:
179
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
180
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
181
+ return original_repeat(x, repeats)
182
+
183
+ ops.repeat = repeat_with_meta
184
+
185
+
186
+ # Install meta kernels on import
187
+ try:
188
+ _install_meta_kernels()
189
+ except Exception as e:
190
+ # If meta kernel installation fails, continue without them
191
+ # torch.compile may not work but the library will still function
192
+ import warnings
193
+
194
+ warnings.warn(
195
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
196
+ )
197
+
198
 
199
  # Set the expert model parallel attributes on a tensor
200
  def set_expert_model_parallel_attributes(
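The register_fake / impl_abstract import that the new code falls back through is normally used to attach a shape-only "fake" implementation to a custom operator so torch.compile can trace through it; the ops in this file are instead wrapped with torch.compiler.is_compiling() checks. A minimal sketch of the registration pattern itself, on a hypothetical demo::histogram op rather than a real MegaBlocks kernel (requires a recent PyTorch; the fallbacks above cover older releases):

    import torch

    @torch.library.custom_op("demo::histogram", mutates_args=())
    def histogram(x: torch.Tensor, max_val: int) -> torch.Tensor:
        # Eager implementation: the real work happens here.
        return torch.bincount(x, minlength=max_val)[:max_val].to(torch.int32)

    @torch.library.register_fake("demo::histogram")
    def _(x, max_val):
        # Fake implementation: only shape/dtype/device metadata is produced.
        return x.new_empty((max_val,), dtype=torch.int32)

    @torch.compile(fullgraph=True)
    def tokens_per_expert(indices: torch.Tensor) -> torch.Tensor:
        return histogram(indices, 8)

    print(tokens_per_expert(torch.randint(0, 8, (64,))))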
 
269
  def route_tokens(
270
  x: torch.Tensor,
271
  router_weight: torch.Tensor,
272
+ router_bias: torch.Tensor,
273
  moe_top_k: int,
274
  moe_num_experts: int,
275
  moe_jitter_eps: float = None,
 
281
  x = apply_jitter(x, moe_jitter_eps)
282
 
283
  x_flat = x.view(-1, x.shape[-1])
284
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
285
  expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
286
  expert_weights = expert_weights.softmax(dim=-1)
287
  if moe_normalize_expert_weights is not None:
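route_tokens now passes the router bias into the logits; the ordering after that is unchanged: top-k is taken over the raw logits and the softmax is applied only to the selected scores. A toy version of that routing step (torch.topk stands in for compute_top_k, which is defined earlier in this file):

    import torch

    tokens, hidden, num_experts, top_k = 4, 16, 8, 2
    x = torch.randn(tokens, hidden)
    router_weight = torch.randn(num_experts, hidden)
    router_bias = torch.randn(num_experts)

    logits = torch.nn.functional.linear(x, router_weight, router_bias)  # [4, 8]
    expert_weights, expert_indices = logits.topk(top_k, dim=-1)         # [4, 2] each
    expert_weights = expert_weights.softmax(dim=-1)                     # rows sum to 1
    print(expert_indices.shape, expert_weights.sum(dim=-1))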
 
319
  w2_bias: torch.Tensor,
320
  gradient_scale: Optional[float] = None,
321
  alpha: float = 1.702,
322
+ limit: float = 7.0,
323
  ):
324
  # Scale weights
325
  w1 = scale_grad(w1, gradient_scale)
 
335
 
336
  # Forward pass
337
  gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
338
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
339
+ gate = gate.clamp(min=None, max=limit)
340
+ up = up.clamp(min=-limit, max=limit)
341
  glu = gate * torch.sigmoid(gate * alpha)
342
+ next_states = torch.bmm(((up + 1) * glu), w2)
343
+ next_states += w2_bias[..., None, :]
344
+ return next_states
 
345
 
346
  # Shared expert MLP forward pass
347
  def shared_mlp_forward(
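The rewritten expert MLP treats the w1 output as interleaved gate/up channels (even columns gate, odd columns up), clamps both before the gated activation, and folds the (up + 1) scaling into the second matmul. A single-expert, plain-matmul sketch of the same arithmetic (shapes are illustrative; the real code runs this batched per expert with torch.bmm):

    import torch

    tokens, hidden, ffn = 4, 16, 32
    alpha, limit = 1.702, 7.0

    x = torch.randn(tokens, hidden)
    w1 = torch.randn(hidden, 2 * ffn)   # columns hold interleaved [gate, up] pairs
    w1_bias = torch.randn(2 * ffn)
    w2 = torch.randn(ffn, hidden)
    w2_bias = torch.randn(hidden)

    gate_up = x @ w1 + w1_bias
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # de-interleave
    gate = gate.clamp(max=limit)                      # one-sided clamp on gate
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    out = ((up + 1) * glu) @ w2 + w2_bias
    print(out.shape)                                  # torch.Size([4, 16])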
 
375
 
376
  # Up projection
377
  x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
378
+
379
  # Activation
380
  x = activation_fn(x)
381
+
382
  # Down projection
383
  x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
384
+
385
  return x
386
 
387
 
 
848
  def moe_forward(
849
  x: torch.Tensor,
850
  router_weight: torch.Tensor,
851
+ router_bias: Optional[torch.Tensor],
852
  moe_top_k: int,
853
  moe_num_experts: int,
854
  moe_jitter_eps: float = None,
 
874
  logits, expert_weights, expert_indices = route_tokens(
875
  x,
876
  router_weight,
877
+ router_bias,
878
  moe_top_k,
879
  moe_num_experts,
880
  moe_jitter_eps,
 
936
  def moe_forward_with_shared_expert(
937
  x: torch.Tensor,
938
  router_weight: torch.Tensor,
939
+ router_bias: Optional[torch.Tensor],
940
  moe_top_k: int,
941
  moe_num_experts: int,
942
  moe_jitter_eps: float = None,
 
969
  expert_out, expert_weights, router_scores = moe_forward(
970
  x=x,
971
  router_weight=router_weight,
972
+ router_bias=router_bias,
973
  moe_top_k=moe_top_k,
974
  moe_num_experts=moe_num_experts,
975
  moe_jitter_eps=moe_jitter_eps,
 
990
  hidden_size=hidden_size,
991
  mlp_impl=mlp_impl,
992
  )
993
+
994
  # If shared expert weights provided, compute shared expert output
995
  if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
996
  shared_expert_out = shared_mlp_forward(
 
1002
  activation_fn=shared_activation_fn,
1003
  gradient_scale=gradient_scale,
1004
  )
1005
+
1006
  # Combine expert outputs
1007
  combined_out = combine_expert_shared_outputs(
1008
  shared_expert_out=shared_expert_out,
 
1010
  shared_expert_weighted_sum=shared_expert_weighted_sum,
1011
  moe_top_k=moe_top_k,
1012
  )
1013
+
1014
  return combined_out, expert_weights, router_scores
1015
+
1016
  # Return regular MoE output if no shared expert
1017
  return expert_out, expert_weights, router_scores
1018
 
 
1028
 
1029
  if output_layer_init_method is None:
1030
  output_layer_init_method = init_method
1031
+
1032
  # Create weight tensors
1033
  up_proj_weight = torch.empty(
1034
  shared_expert_hidden_size,
 
1042
  device=device,
1043
  dtype=dtype,
1044
  )
1045
+
1046
  # Initialize weights
1047
  init_method(up_proj_weight)
1048
  output_layer_init_method(down_proj_weight)
1049
+
1050
  # No bias by default
1051
  return up_proj_weight, down_proj_weight, None, None
1052
 
1053
+
1054
  # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
1055
  # This exists because device_mesh is trapped in hook closures with no model attribute
1056
  # Fragile - breaks if hook structure changes or Python internals change
 
1059
  # Extract device_mesh from child's unused pre_hook closure
1060
  try:
1061
  # Find the pre-hook that contains 'device_mesh' in its closure
1062
+ hook = next(
1063
+ h
1064
+ for h in model.experts._forward_pre_hooks.values()
1065
+ if "device_mesh" in h.__code__.co_freevars
1066
+ )
1067
  # Extract the device_mesh from the closure
1068
+ return hook.__closure__[
1069
+ hook.__code__.co_freevars.index("device_mesh")
1070
+ ].cell_contents
1071
  except Exception:
1072
  return None
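The closure trick above works because a Python function exposes its free-variable names through __code__.co_freevars and their values through the matching __closure__ cells. A standalone illustration of exactly that lookup (the hook factory here is hypothetical, nothing MegaBlocks-specific):

    def make_pre_hook(device_mesh):
        def pre_hook(module, args):
            # device_mesh is only reachable through this closure
            return args
        return pre_hook

    hook = make_pre_hook("mesh-0")
    idx = hook.__code__.co_freevars.index("device_mesh")
    print(hook.__closure__[idx].cell_contents)  # "mesh-0"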
1073
 
 
1081
  alpha = getattr(self.experts, "alpha", 1.0)
1082
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1083
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1084
+ moe_normalize_expert_weights = getattr(
1085
+ self.experts, "normalize_expert_weights", None
1086
+ )
1087
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1088
 
1089
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1091
  device_mesh = get_device_mesh(self)
1092
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1093
 
1094
+ has_parallel = (
1095
+ expert_parallel_group is not None
1096
+ and dist.is_initialized()
1097
+ and dist.get_world_size(expert_parallel_group) > 1
1098
+ )
1099
  forward_fn = parallel_forward_once if has_parallel else forward_once
1100
+
1101
+ sort_end_bit = max(
1102
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1103
+ )
1104
  mlp_impl = getattr(self, "mlp_impl", "grouped")
 
1105
  output, expert_weights_out, *_ = moe_forward(
1106
  x=x,
1107
  router_weight=self.router.weight,
1108
+ router_bias=self.router.bias,
1109
  moe_top_k=moe_top_k,
1110
  moe_num_experts=moe_num_experts,
1111
  moe_jitter_eps=moe_jitter_eps,
 
1129
  return output, expert_weights_out
1130
 
1131
 
1132
+ # Export main classes
1133
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
1134
+
1135
+
1136
  class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
1137
+
1138
  def __init__(self):
1139
  super().__init__()
1140
  # Shared expert weights will be set by the user
 
1144
  self.shared_down_proj_bias = None
1145
  self.shared_expert_weighted_sum = False
1146
  self.shared_activation_fn = None
1147
+
1148
  def set_shared_expert_weights(
1149
  self,
1150
  up_proj_weight: torch.Tensor,
 
1160
  self.shared_down_proj_bias = down_proj_bias
1161
  self.shared_expert_weighted_sum = weighted_sum
1162
  self.shared_activation_fn = activation_fn
1163
+
1164
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1165
  moe_top_k = getattr(self.router, "top_k", 4)
1166
  moe_num_experts = getattr(self.experts, "num_experts", 128)
 
1168
  alpha = getattr(self.experts, "alpha", 1.0)
1169
  moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1170
  moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1171
+ moe_normalize_expert_weights = getattr(
1172
+ self.experts, "normalize_expert_weights", None
1173
+ )
1174
  uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1175
 
1176
  expert_parallel_group = getattr(self, "expert_parallel_group", None)
 
1178
  device_mesh = get_device_mesh(self)
1179
  expert_parallel_group = device_mesh.get_group() if device_mesh else None
1180
 
1181
+ has_parallel = (
1182
+ expert_parallel_group is not None
1183
+ and dist.is_initialized()
1184
+ and dist.get_world_size(expert_parallel_group) > 1
1185
+ )
1186
  forward_fn = parallel_forward_once if has_parallel else forward_once
1187
+
1188
+ sort_end_bit = max(
1189
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1190
+ )
1191
  mlp_impl = getattr(self, "mlp_impl", "grouped")
1192
+
1193
  output, expert_weights_out, *_ = moe_forward_with_shared_expert(
1194
  x=x,
1195
  router_weight=self.router.weight,
1196
+ router_bias=self.router.bias,
1197
  moe_top_k=moe_top_k,
1198
  moe_num_experts=moe_num_experts,
1199
  moe_jitter_eps=moe_jitter_eps,
 
1221
  shared_expert_weighted_sum=self.shared_expert_weighted_sum,
1222
  shared_activation_fn=self.shared_activation_fn,
1223
  )
1224
+ return output, expert_weights_out
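
sort_end_bit, computed the same way in both forward methods, is the number of radix-sort bits needed to cover the expert ids, floored at one. A quick check of the formula with a few illustrative expert counts:

    import torch

    def sort_end_bit(moe_num_experts: int) -> int:
        return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

    assert sort_end_bit(1) == 1     # floored at one bit
    assert sort_end_bit(3) == 2
    assert sort_end_bit(128) == 7   # 128 is the default num_experts in these layers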