drbh
commited on
Commit
·
8176cbe
1
Parent(s):
cd5b9c4
fix: include torch compile flag
Browse files
torch-ext/megablocks/layers.py
CHANGED
@@ -1073,6 +1073,7 @@ def get_device_mesh(model):
|
|
1073 |
|
1074 |
|
1075 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
|
|
1076 |
|
1077 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1078 |
moe_top_k = getattr(self.router, "top_k", 4)
|
|
|
1073 |
|
1074 |
|
1075 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
1076 |
+
can_torch_compile: bool = True
|
1077 |
|
1078 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1079 |
moe_top_k = getattr(self.router, "top_k", 4)
|