drbh commited on
Commit
8176cbe
·
1 Parent(s): cd5b9c4

fix: include torch compile flag

Browse files
Files changed (1) hide show
  1. torch-ext/megablocks/layers.py +1 -0
torch-ext/megablocks/layers.py CHANGED
@@ -1073,6 +1073,7 @@ def get_device_mesh(model):
1073
 
1074
 
1075
  class MegaBlocksMoeMLP(torch.nn.Module):
 
1076
 
1077
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1078
  moe_top_k = getattr(self.router, "top_k", 4)
 
1073
 
1074
 
1075
  class MegaBlocksMoeMLP(torch.nn.Module):
1076
+ can_torch_compile: bool = True
1077
 
1078
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1079
  moe_top_k = getattr(self.router, "top_k", 4)