autoprogrammer committed
Commit 827f943 · verified · 1 Parent(s): 96ca21e

Update modeling_densebackward_olmoe0125.py

Files changed (1)
  1. modeling_densebackward_olmoe0125.py  +123 -34
modeling_densebackward_olmoe0125.py CHANGED
@@ -25,7 +25,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
     """
     def forward(self, hidden_states: torch.Tensor):
         """
-        Gate version of the straight-through implementation: π -> mask, dmask/dπ = 1
         """
         batch_size, seq_length, hidden_dim = hidden_states.shape
         dtype = hidden_states.dtype
@@ -34,49 +34,73 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
         flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
         N_tokens = flat_hidden.size(0)

-        # 1) router & softmax
-        router_logits = self.gate(flat_hidden).to(dtype=dtype)  # (N, num_experts)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (N, num_experts)

-        # 2) top-K selection
-        _, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)  # (N, K)
-
-        # 3) build hard & STE masks
-        mask_hard = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype)  # (N, num_experts)
-        # mask_ste = mask_hard + (routing_weights - routing_weights.detach())
-        mask_ste = mask_hard + (router_logits - router_logits.detach())
-
-        # 4) compute gated weights = π * mask, then optionally renormalize
-        gated = routing_weights * mask_ste  # zero out non-top-k
         if self.norm_topk_prob:
-            norm_ratio = gated.sum(dim=-1, keepdim=True)  # (N, 1)
-            gated = gated / norm_ratio  # normalized top-k

-        # 5) prepare accumulators
         dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
         sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)

-        for expert_idx, expert_layer in enumerate(self.experts):
-            expert_output = expert_layer(flat_hidden).to(dtype=dtype)  # (N_tokens, hidden_dim)
-            activation_mask = (selected_experts == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)

             if expert_output.requires_grad:
                 expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
-
-            # a) Dense-STE backward uses gated weights
-            weights = gated[:, expert_idx].unsqueeze(-1)  # (N_tokens, 1)
-            dense_outputs += expert_output * weights
-
-            # b) Sparse forward -- find tokens where this expert is among top_k (active experts)
-            active = (selected_experts == expert_idx)
-            if active.any():
-                token_indices, _ = torch.where(active)
-                weights_topk = gated[token_indices, expert_idx].unsqueeze(-1)  # (num_matches, 1)
-                sparse_outputs[token_indices] += expert_output[token_indices] * weights_topk
-
-        # 6) STE mix: forward from sparse, backward from dense
         final_flat = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
-        final_output = final_flat.view(batch_size, seq_length, hidden_dim).to(dtype=dtype)

         return final_output, router_logits
@@ -85,6 +109,71 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):



     # def forward(self, hidden_states: torch.Tensor):
     #     batch_size, seq_length, hidden_dim = hidden_states.shape
     #     # Record the input tensor's dtype so that all computations stay consistent
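The removed "gate version" above rests on the usual straight-through construction hard + (soft - soft.detach()): in the forward pass the expression equals the hard top-k mask, while in the backward pass the gradient flows through the soft term as if dmask/dπ = 1. A minimal, self-contained sketch of that pattern (toy shapes and tensor names, not the module's own):

import torch
import torch.nn.functional as F

# Hypothetical toy setup: 4 tokens, 8 experts, top-2 routing.
logits = torch.randn(4, 8, requires_grad=True)
pi = F.softmax(logits, dim=-1)

_, topk_idx = torch.topk(pi, k=2, dim=-1)
mask_hard = F.one_hot(topk_idx, num_classes=8).sum(dim=1).float()  # (4, 8), entries in {0, 1}

# Straight-through: equals mask_hard in value, passes gradients to logits with dmask/dlogits = 1.
mask_ste = mask_hard + (logits - logits.detach())

gated = pi * mask_ste        # forward: non-top-k entries are zeroed, as in the removed code
gated.sum().backward()       # backward: logits.grad is populated for every expert, not only the top-k

The added lines below keep the same sparse-forward / dense-backward mix of the expert outputs but replace this gate-level mask with a temperature-scaled routing path (routing_weights_tau, the softmax of the logits divided by τ = 1.1) and a "partscale" renormalization of the top-k weights.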
 
     """
     def forward(self, hidden_states: torch.Tensor):
         """
+        forward_partscale_fixep_norm_dtch_tau
         """
         batch_size, seq_length, hidden_dim = hidden_states.shape
         dtype = hidden_states.dtype

         flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
         N_tokens = flat_hidden.size(0)

+        # Compute router logits and routing weights
+        router_logits = self.gate(flat_hidden).to(dtype=dtype)  # (B*L, num_experts)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (B*L, num_experts)
+        routing_weights_tau = F.softmax(router_logits / 1.1, dim=1, dtype=torch.float)  # (B*L, num_experts)

+        # Select top-k experts
+        routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights_topk_tau, selected_experts_tau = torch.topk(routing_weights_tau, self.top_k, dim=-1)
         if self.norm_topk_prob:
+            norm_ratio = routing_weights_topk.sum(dim=-1, keepdim=True)
+            # Normalize top-k routing weights
+            routing_weights_topk = routing_weights_topk / norm_ratio
+            # Only scale the selected top-k positions in routing_weights
+            mask = F.one_hot(selected_experts_tau, num_classes=self.num_experts).sum(dim=1).to(dtype)
+            routing_weights_topk_tau = routing_weights_tau * mask
+            norm_ratio_dense = routing_weights_topk_tau.sum(dim=-1, keepdim=True)
+            # ------------------------------------ Choose Section ------------------------------------
+            # current --> partscale_fix_expert implementation
+            routing_weights_tau = routing_weights_tau * (1.0 - mask) / norm_ratio_dense.detach() + routing_weights_topk_tau / norm_ratio_dense
+            routing_weights = routing_weights * (1.0 - mask) / norm_ratio.detach() + routing_weights * mask / norm_ratio
+
+            # should be --> the gated implementation: comment out the two lines above and uncomment the two lines below
+            # gated = routing_weights.detach() * mask + (routing_weights - routing_weights.detach())
+            # routing_weights = gated / gated.sum(dim=-1, keepdim=True)
+            # ------------------------------------ Choose Section ------------------------------------

+        routing_weights_topk = routing_weights_topk.to(dtype=dtype)
+
+        # Convert the full routing weights to a consistent dtype for dense accumulation
+        # routing_weights = routing_weights.to(dtype=dtype)
+        routing_weights_tau = routing_weights_tau.to(dtype=dtype)
+
+        # Prepare accumulators: one for dense_outputs, one for sparse_outputs
         dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
         sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)

+        # For mapping top-k positions when accumulating sparse_outputs
+        # selected_experts: (N_tokens, top_k)

+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            # Compute the current expert's output for all tokens
+            expert_output = expert_layer(flat_hidden).to(dtype=dtype)  # (N_tokens, hidden_dim)
+            activation_mask = (selected_experts_tau == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
             if expert_output.requires_grad:
                 expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
+            expert_output = expert_output.to(dtype=dtype)
+            # Dense accumulation: multiply by the full routing weight and add
+            weight_full_tau = routing_weights_tau[:, expert_idx].unsqueeze(-1)  # (N_tokens, 1)
+            weight_full = routing_weights[:, expert_idx].unsqueeze(-1)  # (N_tokens, 1)
+            dense_outputs = dense_outputs + expert_output * (weight_full_tau - weight_full_tau.detach()) + expert_output * weight_full.detach()
+
+            # Sparse accumulation: find tokens where this expert is among the top_k
+            # matches: Boolean mask where selected_experts == expert_idx → shape (N_tokens, top_k)
+            matches = (selected_experts == expert_idx)
+            if matches.any():
+                # locations: tuple of (token_indices, k_indices)
+                token_indices, k_indices = torch.where(matches)
+                # corresponding top-k weights
+                weights_topk = routing_weights_topk[token_indices, k_indices].unsqueeze(-1)  # (num_matches, 1)
+                # Accumulate sparse_outputs only for the matched tokens
+                sparse_outputs[token_indices] = sparse_outputs[token_indices] + expert_output[token_indices] * weights_topk
+
+        # Combine the sparse forward output with the dense backward output
         final_flat = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
+        final_flat = final_flat.to(dtype=dtype)
+        final_output = final_flat.view(batch_size, seq_length, hidden_dim)

         return final_output, router_logits
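The line final_flat = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach()) is the heart of both versions: the forward value is exactly the sparse MoE output, while all gradients flow through the densely accumulated path. A minimal sketch of that effect, using hypothetical stand-in tensors rather than the block's real accumulators:

import torch

# Hypothetical stand-ins for the two accumulators built in the expert loop.
dense = torch.randn(5, 16, requires_grad=True)   # dense path: every expert contributes
sparse = torch.randn(5, 16)                      # sparse path: only top-k experts contribute

out = sparse.detach() + (dense - dense.detach())

print(torch.allclose(out, sparse))                           # True: forward value is the sparse output
out.sum().backward()
print(torch.allclose(dense.grad, torch.ones_like(dense)))    # True: gradient flows through the dense path

In the updated forward, the dense accumulator itself uses a second detach trick: expert_output * (weight_full_tau - weight_full_tau.detach()) + expert_output * weight_full.detach() keeps the value of the dense path tied to routing_weights while routing the gradient with respect to the router through the τ-scaled routing_weights_tau.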



+
+
+
+    # def forward(self, hidden_states: torch.Tensor):
+    #     """
+    #     Gate version of the straight-through implementation: π -> mask, dmask/dπ = 1
+    #     """
+    #     batch_size, seq_length, hidden_dim = hidden_states.shape
+    #     dtype = hidden_states.dtype
+    #     device = hidden_states.device
+
+    #     flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
+    #     N_tokens = flat_hidden.size(0)
+
+    #     # 1) router & softmax
+    #     router_logits = self.gate(flat_hidden).to(dtype=dtype)  # (N, num_experts)
+    #     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (N, num_experts)
+
+    #     # 2) top-K selection
+    #     _, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)  # (N, K)
+
+    #     # 3) build hard & STE masks
+    #     mask_hard = F.one_hot(selected_experts, num_classes=self.num_experts).sum(dim=1).to(dtype)  # (N, num_experts)
+    #     # mask_ste = mask_hard + (routing_weights - routing_weights.detach())
+    #     mask_ste = mask_hard + (router_logits - router_logits.detach())
+
+    #     # 4) compute gated weights = π * mask, then optionally renormalize
+    #     gated = routing_weights * mask_ste  # zero out non-top-k
+    #     if self.norm_topk_prob:
+    #         norm_ratio = gated.sum(dim=-1, keepdim=True)  # (N, 1)
+    #         gated = gated / norm_ratio  # normalized top-k
+
+    #     # 5) prepare accumulators
+    #     dense_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
+    #     sparse_outputs = torch.zeros((N_tokens, hidden_dim), dtype=dtype, device=device)
+
+    #     for expert_idx, expert_layer in enumerate(self.experts):
+    #         expert_output = expert_layer(flat_hidden).to(dtype=dtype)  # (N_tokens, hidden_dim)
+    #         activation_mask = (selected_experts == expert_idx).any(dim=1).float().unsqueeze(-1).to(dtype)
+
+    #         if expert_output.requires_grad:
+    #             expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)
+
+    #         # a) Dense-STE backward uses gated weights
+    #         weights = gated[:, expert_idx].unsqueeze(-1)  # (N_tokens, 1)
+    #         dense_outputs += expert_output * weights
+
+    #         # b) Sparse forward -- find tokens where this expert is among top_k (active experts)
+    #         active = (selected_experts == expert_idx)
+    #         if active.any():
+    #             token_indices, _ = torch.where(active)
+    #             weights_topk = gated[token_indices, expert_idx].unsqueeze(-1)  # (num_matches, 1)
+    #             sparse_outputs[token_indices] += expert_output[token_indices] * weights_topk
+
+    #     # 6) STE mix: forward from sparse, backward from dense
+    #     final_flat = sparse_outputs.detach() + (dense_outputs - dense_outputs.detach())
+    #     final_output = final_flat.view(batch_size, seq_length, hidden_dim).to(dtype=dtype)
+
+    #     return final_output, router_logits
+
+
     # def forward(self, hidden_states: torch.Tensor):
     #     batch_size, seq_length, hidden_dim = hidden_states.shape
     #     # Record the input tensor's dtype so that all computations stay consistent
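Two smaller mechanisms in the updated forward are easy to miss: the temperature-scaled router softmax (logits divided by τ = 1.1, used to pick selected_experts_tau and for the backward-side weights) and the register_hook call that zeroes an expert's incoming gradient for tokens that did not select it. A rough, self-contained sketch under assumed toy shapes (none of these names are the module's own):

import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k, tau = 4, 8, 2, 1.1

logits = torch.randn(num_tokens, num_experts)
weights = F.softmax(logits, dim=-1)            # standard routing weights (forward/sparse path)
weights_tau = F.softmax(logits / tau, dim=-1)  # flatter, temperature-scaled weights (backward/dense path)

_, selected_tau = torch.topk(weights_tau, top_k, dim=-1)

# Activation mask for a single expert (expert 0 here): 1 where the expert was selected, else 0.
expert_idx = 0
activation_mask = (selected_tau == expert_idx).any(dim=-1).float().unsqueeze(-1)  # (num_tokens, 1)

hidden = torch.randn(num_tokens, 16, requires_grad=True)
expert_output = hidden * 2.0                   # stand-in for expert_layer(flat_hidden)
# The hook multiplies the incoming gradient by the mask, so tokens that were not routed to
# this expert send no gradient back through it, even though it ran densely on every token.
expert_output.register_hook(lambda grad, mask=activation_mask: grad * mask)

expert_output.sum().backward()
print(hidden.grad)                             # rows are zero for tokens that did not select expert 0

Note that in the commit the hook's mask is derived from selected_experts_tau, while the sparse accumulation still indexes with the un-scaled selected_experts; that split is a design choice of this change.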