autoprogrammer committed on
Commit
96ca21e
·
verified ·
1 Parent(s): d37cdb2

Update modeling_densebackward_olmoe0125.py

Browse files
modeling_densebackward_olmoe0125.py CHANGED
@@ -43,8 +43,8 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
43
 
44
  # 3) build hard & ste masks
45
  mask_hard = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype) # (N, num_experts)
46
- mask_ste = mask_hard + (routing_weights - routing_weights.detach())
47
- # mask_ste = mask_hard + (router_logits - router_logits.detach())
48
 
49
  # 4) compute gated weights = π * mask, then optionally renormalize
50
  gated = routing_weights * mask_ste # zero-out non-TopK
 
43
 
44
  # 3) build hard & ste masks
45
  mask_hard = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype) # (N, num_experts)
46
+ #mask_ste = mask_hard + (routing_weights - routing_weights.detach())
47
+ mask_ste = mask_hard + (router_logits - router_logits.detach())
48
 
49
  # 4) compute gated weights = π * mask, then optionally renormalize
50
  gated = routing_weights * mask_ste # zero-out non-TopK