Upload moe_idefics2.py
moe_idefics2.py +430 -0
moe_idefics2.py
ADDED
@@ -0,0 +1,430 @@
from transformers import AutoProcessor, Idefics2ForConditionalGeneration
from transformers import Cache
from transformers import Idefics2Model
from transformers.models.idefics2.modeling_idefics2 import Idefics2MLP
from transformers.models.idefics2.modeling_idefics2 import Idefics2Config, Idefics2PreTrainedModel

from tqdm.auto import tqdm  # auto variant works in notebooks and plain scripts alike
import torch
import torch.nn as nn
import copy

from torch.optim import Adam
from typing import Optional, Tuple

from transformers import (
    PreTrainedModel,
    AutoConfig,
)

import torch.nn.functional as F

import matplotlib.pyplot as plt

# Define the Gating Layer (router): scores every expert and keeps the top-k per token
class GatingLayer(nn.Module):
    def __init__(self, input_dim, num_experts, k, layer_dtype=torch.float16):
        super(GatingLayer, self).__init__()
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)

    def forward(self, x):
        gate_scores = torch.softmax(self.gate(x), dim=-1)
        topk_values, topk_indices = torch.topk(gate_scores, self.k, dim=-1)
        # Renormalize the top-k scores so the selected experts' weights sum to one
        topk_values = F.softmax(topk_values, dim=-1)

        return topk_values, topk_indices

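# Toy sanity check for GatingLayer (illustrative sketch only; the dimensions below
# are arbitrary, whereas the real router uses config.text_config.hidden_size and
# runs in the model's dtype).
def _example_gating_layer():
    gating = GatingLayer(input_dim=16, num_experts=4, k=2, layer_dtype=torch.float32)
    x = torch.randn(2, 5, 16)  # [batch, seq_len, hidden]
    topk_values, topk_indices = gating(x)
    # topk_values: renormalized routing weights, topk_indices: selected expert ids
    assert topk_values.shape == (2, 5, 2)
    assert topk_indices.shape == (2, 5, 2)
    return topk_values, topk_indices
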
class MoE(nn.Module):
    def __init__(self, input_dim, experts, gating_layer, config):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.gating_layer = gating_layer
        self.output_dim = config.text_config.hidden_size

    def forward(self, x):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            gate_values, gate_indices = self.gating_layer(x)
            batch_size, seq_length, _ = x.size()

            # Compute every expert's output so the routed ones can be gathered below
            expert_outputs = []
            for expert in self.experts:
                expert_outputs.append(expert.down_proj(expert.act_fn(expert.gate_proj(x)) * expert.up_proj(x)))
            '''
            Alternative for MLPs with a fused gate_up_proj projection:

            up_states = expert.gate_up_proj(x.view(-1, x.size(-1)))  # Flatten to [batch_size * seq_length, input_dim]
            gate, up_states = up_states.chunk(2, dim=-1)
            up_states = up_states * expert.activation_fn(gate)
            expert_output = expert.down_proj(up_states)
            expert_outputs.append(expert_output.view(batch_size, seq_length, -1))
            '''

            expert_outputs = torch.stack(expert_outputs, dim=-1)  # Shape: [batch_size, seq_length, hidden_size, num_experts]

            # Use torch.gather to select the expert outputs based on gate_indices
            expanded_gate_indices = gate_indices.unsqueeze(-2).expand(-1, -1, self.output_dim, -1)  # Shape: [batch_size, seq_length, hidden_size, k]
            selected_expert_outputs = torch.gather(expert_outputs, -1, expanded_gate_indices)  # Shape: [batch_size, seq_length, hidden_size, k]

            # Weight the selected expert outputs by gate values
            gate_values = gate_values.unsqueeze(-2)  # Shape: [batch_size, seq_length, 1, k]
            weighted_expert_outputs = selected_expert_outputs * gate_values  # Shape: [batch_size, seq_length, hidden_size, k]

            # Sum the weighted expert outputs across the k dimension
            moe_output = weighted_expert_outputs.sum(dim=-1)  # Shape: [batch_size, seq_length, hidden_size]

        return moe_output.to(self.gating_layer.gate.weight.dtype)

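# Minimal routing sketch for the MoE block above, using tiny stand-in experts that
# expose the same gate_proj/up_proj/down_proj/act_fn interface as Idefics2MLP.
# _ToyExpert, the SimpleNamespace config, and all sizes are invented for this
# illustration and are not part of the Idefics2 checkpoints.
def _example_moe_block(hidden_size=16, num_experts=3, k=2):
    import types

    class _ToyExpert(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
            self.up_proj = nn.Linear(hidden_size, hidden_size, bias=False)
            self.down_proj = nn.Linear(hidden_size, hidden_size, bias=False)
            self.act_fn = nn.SiLU()

    toy_config = types.SimpleNamespace(
        text_config=types.SimpleNamespace(hidden_size=hidden_size)
    )
    experts = [_ToyExpert(hidden_size) for _ in range(num_experts)]
    gating = GatingLayer(hidden_size, num_experts, k, layer_dtype=torch.float32)
    moe = MoE(hidden_size, experts, gating, toy_config)

    x = torch.randn(2, 4, hidden_size)  # [batch, seq_len, hidden]
    out = moe(x)                        # each token mixes its top-k experts
    assert out.shape == x.shape
    return out
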
class ModifiedIdefics2DecoderLayer(nn.Module):

    def __init__(self, original_layer, moe_layer):
        super(ModifiedIdefics2DecoderLayer, self).__init__()
        self.self_attn = original_layer.self_attn
        self.mlp = moe_layer
        self.input_layernorm = original_layer.input_layernorm
        self.post_attention_layernorm = original_layer.post_attention_layernorm
        #print ("Init: ModifiedIdefics2DecoderLayer")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            ### MJB TODO: Need to check
            #cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

# Define Idefics2ForCausalLMMoEConfig
class Idefics2ForCausalLMMoEConfig(Idefics2Config):
    model_type = "idefics2_moe"

    def __init__(self, config=None, k=1, num_expert_models=2, use_embeddings_in_router=False, **kwargs):
        if config is not None:
            kwargs.update(config.to_dict())
        super().__init__(**kwargs)
        self.k = k
        self.num_expert_models = num_expert_models
        self.architectures = ["Idefics2ForCausalLMMoE"]  # architectures is expected to be a list of class names
        self.auto_map = {
            "AutoConfig": "moe_idefics2.Idefics2ForCausalLMMoEConfig",
            "AutoModelForCausalLM": "moe_idefics2.Idefics2ForCausalLMMoE",
        }
        self.use_embeddings_in_router = use_embeddings_in_router

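# Minimal sketch of constructing the MoE config by wrapping an existing Idefics2
# config. The checkpoint id is only a placeholder for whatever base model the
# experts were fine-tuned from.
def _example_moe_config():
    base_config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")  # placeholder checkpoint
    moe_config = Idefics2ForCausalLMMoEConfig(
        config=base_config,
        k=1,                   # experts routed to per token
        num_expert_models=3,   # experts per decoder layer
    )
    return moe_config
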
# Define the MoE model
class Idefics2ForCausalLMMoE(Idefics2ForConditionalGeneration):

    config_class = Idefics2ForCausalLMMoEConfig

    def __init__(
        self,
        config,
        base_model=None,
        expert_models=None,
        layer_dtype=torch.bfloat16,
        **kwargs,
    ):
        super().__init__(config)

        self.layer_dtype = layer_dtype
        self.custom_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        k = self.config.k

        self.num_layers = len(base_model.model.text_model.layers) if base_model else 0

        self.config.auto_map = {
            "AutoConfig": "moe_idefics2.Idefics2ForCausalLMMoEConfig",
            "AutoModelForCausalLM": "moe_idefics2.Idefics2ForCausalLMMoE",
        }

        self.use_embeddings_in_router = config.use_embeddings_in_router
        print("Use embeddings in router: ", self.use_embeddings_in_router)

        self.model = base_model or Idefics2ForConditionalGeneration(
            self.config
        )

        if base_model and expert_models:
            self.num_expert_models = len(expert_models)
            self._init_moe_layers(base_model, expert_models, k, layer_dtype)
            print("CONSTRUCTOR self.model", self.model)
        else:
            print(
                "Init function called and generating dummy experts: k=",
                k,
                "experts=",
                self.config.num_expert_models,
            )
            num_dummy_experts = self.config.num_expert_models
            self._init_moe_layers_with_dummy_experts(
                self.model, k, num_dummy_experts, layer_dtype
            )

        self.config.model_type = "idefics2_moe"

    def _init_base_model(self):
        print("init base model")
        return PreTrainedModel(self.config)

    def _init_moe_layers(self, base_model, expert_models, k, layer_dtype):

        self.num_layers = len(base_model.model.text_model.layers)

        for i in tqdm(range(self.num_layers)):
            experts = []
            for expert_model in expert_models:
                # Copy this layer's MLP from each expert model to serve as one expert
                expert = copy.deepcopy(expert_model.model.text_model.layers[i].mlp).to(
                    dtype=layer_dtype
                )
                experts.append(expert)

            gating_layer = GatingLayer(
                input_dim=self.config.text_config.hidden_size,
                num_experts=len(experts),
                k=k,
                layer_dtype=layer_dtype,
            )
            moe_layer = MoE(
                input_dim=self.config.text_config.hidden_size,
                experts=experts,
                gating_layer=gating_layer,
                config=self.config,
            ).to(dtype=layer_dtype)

            self.model.model.text_model.layers[i] = ModifiedIdefics2DecoderLayer(
                self.model.model.text_model.layers[i], moe_layer
            ).to(dtype=layer_dtype)

    def _init_moe_layers_with_dummy_experts(
        self, base_model, k, num_dummy_experts, layer_dtype
    ):
        self.num_layers = len(base_model.model.text_model.layers)

        for i in tqdm(range(self.num_layers)):
            experts = []
            for _ in range(num_dummy_experts):
                dummy_expert = Idefics2MLP(
                    hidden_size=self.config.text_config.hidden_size,
                    intermediate_size=self.config.text_config.intermediate_size,
                    output_size=self.config.text_config.hidden_size,
                    hidden_act=self.config.perceiver_config.hidden_act,
                ).to(dtype=layer_dtype)

                experts.append(dummy_expert)

            gating_layer = GatingLayer(
                input_dim=self.config.text_config.hidden_size,
                num_experts=len(experts),
                k=k,
                layer_dtype=layer_dtype,
            )
            moe_layer = MoE(
                input_dim=self.config.text_config.hidden_size,
                experts=experts,
                gating_layer=gating_layer,
                config=self.config,
            ).to(dtype=layer_dtype)

            self.model.model.text_model.layers[i] = ModifiedIdefics2DecoderLayer(
                self.model.model.text_model.layers[i], moe_layer
            ).to(dtype=layer_dtype)

    def get_input_embeddings(self):
        if hasattr(self.model, "text_model"):
            return self.model.text_model.get_input_embeddings()
        return self.model.model.text_model.get_input_embeddings()

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        model = super(Idefics2ForCausalLMMoE, cls).from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        return model

    def plot_loss_histories(self, all_loss_histories, loss_steps, filename="loss_history.svg"):
        plt.figure(figsize=(12, 8))
        for layer_idx, loss_history in enumerate(all_loss_histories):
            plt.plot(
                range(0, len(loss_history) * loss_steps, loss_steps),
                loss_history,
                label=f'Layer {layer_idx}',
                linewidth=2,  # Thicker line
                marker='o'    # Circle marker for each data point
            )
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Loss History per Layer, MoE Gating Network')
        plt.legend()
        plt.grid(True)
        try:
            plt.savefig(filename)
        except Exception:
            print("Figure file save failed...")
        plt.show()

    def train_gating_layer_params_from_hidden_states(self, processor, prompts_per_expert, epochs=1000, loss_steps=100,
                                                     lr=1e-4, layer_offset=0):
        self.to(self.custom_device)
        self.eval()

        print('dtype:', self.layer_dtype, 'device=', self.custom_device)

        all_gating_layer_params = []
        all_loss_histories = []  # To store loss histories for each layer

        expert_hidden_states_per_layer = [[] for _ in range(self.num_layers)]

        # Collect hidden states for each expert
        for prompts in tqdm(prompts_per_expert, desc="Processing Prompts"):
            for prompt in tqdm(prompts, desc="Processing Single Prompt", leave=False):
                inputs = processor(text=prompt['text'], images=prompt['image'], return_tensors="pt").to(self.custom_device).to(self.layer_dtype)
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                hidden_states = outputs.hidden_states
                for layer_idx in tqdm(range(self.num_layers)):
                    hidden_state = hidden_states[layer_idx + layer_offset].mean(dim=1)  # Averaging over the sequence dimension

                    expert_hidden_states_per_layer[layer_idx].append(hidden_state)

        # Train the gating layers
        for layer_idx in tqdm(range(self.num_layers), desc="Training Gating Layers"):
            print(f"Training gating layer parameters for layer {layer_idx}")

            # Ensure we have hidden states collected for the current layer
            if not expert_hidden_states_per_layer[layer_idx]:
                raise ValueError(f"No hidden states collected for layer {layer_idx}")

            # Aggregate hidden states for each expert and stack them
            expert_hidden_states = []
            num_prompts_per_expert = len(prompts_per_expert[0])
            for i in range(len(prompts_per_expert)):
                hidden_states_for_expert = expert_hidden_states_per_layer[layer_idx][i * num_prompts_per_expert: (i + 1) * num_prompts_per_expert]
                hidden_state_avg = torch.stack(hidden_states_for_expert).mean(dim=0)
                expert_hidden_states.append(hidden_state_avg)
            expert_hidden_states = torch.stack(expert_hidden_states).to(self.layer_dtype)

            input_dim = self.config.text_config.hidden_size
            num_experts = self.config.num_expert_models

            class SimpleGatingLayer(nn.Module):
                def __init__(self, input_dim, num_experts, layer_dtype=torch.bfloat16):
                    super(SimpleGatingLayer, self).__init__()
                    self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)

                def forward(self, x):
                    #return torch.softmax(self.gate(x), dim=-1)
                    return self.gate(x)

            gating_layer = SimpleGatingLayer(self.config.text_config.hidden_size,
                                             num_experts, layer_dtype=self.layer_dtype).to(self.custom_device)

            criterion = nn.CrossEntropyLoss()
            optimizer = Adam(gating_layer.parameters(), lr=lr)

            loss_history = []

            for epoch in tqdm(range(epochs), desc=f"Training Gating Layer {layer_idx}"):
                optimizer.zero_grad()
                # Reshape expert_hidden_states to match (batch_size, input_dim)
                expert_hidden_states_reshaped = expert_hidden_states.view(-1, input_dim)
                outputs = gating_layer(expert_hidden_states_reshaped)
                labels = torch.arange(num_experts).to(self.custom_device)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                if epoch % loss_steps == 0:
                    loss_history.append(loss.item())

            all_loss_histories.append(loss_history)
            all_gating_layer_params.append(gating_layer.state_dict())

        self.plot_loss_histories(all_loss_histories, loss_steps)
        return all_gating_layer_params

    def set_gating_layer_params(self, gating_layer_params):
        for layer_idx, params in enumerate(gating_layer_params):
            self.model.model.text_model.layers[layer_idx].mlp.gating_layer.load_state_dict(params)

def freeze_except_gating_layers(model):

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze gating layer parameters
    for layer in model.model.model.text_model.layers:
        for name, param in layer.mlp.gating_layer.named_parameters():
            param.requires_grad = True


def un_freeze_all(model):
    # Unfreeze all parameters
    for param in model.parameters():
        param.requires_grad = True

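# Quick sketch of the intended routing-only training setup: freeze everything except
# the per-layer gating networks, then check how many parameters remain trainable.
# `moe_model` is assumed to be an Idefics2ForCausalLMMoE instance built as above.
def _example_freeze_gating_only(moe_model):
    freeze_except_gating_layers(moe_model)
    trainable = sum(p.numel() for p in moe_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in moe_model.parameters())
    print(f"Trainable (gating-only) parameters: {trainable:,} of {total:,}")
    # un_freeze_all(moe_model)  # call this instead to fine-tune the whole model
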
from transformers import AutoConfig

AutoConfig.register("idefics2_moe", Idefics2ForCausalLMMoEConfig)

from transformers.models.auto.modeling_auto import MODEL_MAPPING

MODEL_MAPPING.update({"idefics2_moe": Idefics2ForCausalLMMoE})
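
# End-to-end usage sketch (illustrative only): the checkpoint id, dtype, and prompt
# layout below are placeholders rather than values prescribed by this module. The
# intended flow is to load a base Idefics2 model plus several fine-tuned experts,
# wrap them into the MoE, train the per-layer routers, and then generate as usual.
def _example_build_and_train_moe():
    model_id = "HuggingFaceM4/idefics2-8b"  # placeholder base checkpoint
    processor = AutoProcessor.from_pretrained(model_id)
    base_model = Idefics2ForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16
    )
    # In practice these would be several differently fine-tuned Idefics2 models;
    # reusing the base model here just keeps the sketch self-contained.
    expert_models = [base_model, base_model]

    moe_config = Idefics2ForCausalLMMoEConfig(
        config=base_model.config, k=1, num_expert_models=len(expert_models)
    )
    moe_model = Idefics2ForCausalLMMoE(
        moe_config,
        base_model=base_model,
        expert_models=expert_models,
        layer_dtype=torch.bfloat16,
    )

    # One list of {"text": ..., "image": ...} prompts per expert is used to
    # supervise the routers; see train_gating_layer_params_from_hidden_states.
    # prompts_per_expert = [...]
    # gating_params = moe_model.train_gating_layer_params_from_hidden_states(
    #     processor, prompts_per_expert, epochs=1000, loss_steps=100)
    # moe_model.set_gating_layer_params(gating_params)
    return moe_model, processor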