Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -43,8 +43,8 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
43 |
|
44 |
# 3) build hard & ste masks
|
45 |
mask_hard = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype) # (N, num_experts)
|
46 |
-
mask_ste = mask_hard + (routing_weights - routing_weights.detach())
|
47 |
-
|
48 |
|
49 |
# 4) compute gated weights = π * mask, then optionally renormalize
|
50 |
gated = routing_weights * mask_ste # zero-out non-TopK
|
|
|
43 |
|
44 |
# 3) build hard & ste masks
|
45 |
mask_hard = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype) # (N, num_experts)
|
46 |
+
#mask_ste = mask_hard + (routing_weights - routing_weights.detach())
|
47 |
+
mask_ste = mask_hard + (router_logits - router_logits.detach())
|
48 |
|
49 |
# 4) compute gated weights = π * mask, then optionally renormalize
|
50 |
gated = routing_weights * mask_ste # zero-out non-TopK
|