mjbuehler committed
Commit cd4062f · verified · Parent: 6241b4b

Upload moe_idefics2.py

Files changed (1):
  moe_idefics2.py  +430 -0
moe_idefics2.py ADDED
@@ -0,0 +1,430 @@
from transformers import AutoProcessor, Idefics2ForConditionalGeneration
from transformers import Cache
from transformers import Idefics2Model
from transformers.models.idefics2.modeling_idefics2 import Idefics2MLP
from transformers.models.idefics2.modeling_idefics2 import Idefics2Config, Idefics2PreTrainedModel

from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import copy

from torch.optim import Adam
from typing import Optional, Tuple

from transformers import (
    PreTrainedModel,
    AutoConfig,
)

import torch.nn.functional as F

import matplotlib.pyplot as plt

# Define the gating (router) layer: scores each token and selects the top-k experts
class GatingLayer(nn.Module):
    def __init__(self, input_dim, num_experts, k, layer_dtype=torch.float16):
        super(GatingLayer, self).__init__()
        self.num_experts = num_experts
        self.k = k
        self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)

    def forward(self, x):
        gate_scores = torch.softmax(self.gate(x), dim=-1)                    # [batch, seq, num_experts]
        topk_values, topk_indices = torch.topk(gate_scores, self.k, dim=-1)  # both [batch, seq, k]
        topk_values = F.softmax(topk_values, dim=-1)  # re-normalize the top-k weights so they sum to 1

        return topk_values, topk_indices

class MoE(nn.Module):
    def __init__(self, input_dim, experts, gating_layer, config):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.gating_layer = gating_layer
        self.output_dim = config.text_config.hidden_size

    def forward(self, x):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            gate_values, gate_indices = self.gating_layer(x)
            batch_size, seq_length, _ = x.size()

            # Run every expert MLP on the full sequence (each expert is an Idefics2MLP-style
            # module exposing gate_proj/up_proj/down_proj and act_fn)
            expert_outputs = []
            for expert in self.experts:
                expert_outputs.append(expert.down_proj(expert.act_fn(expert.gate_proj(x)) * expert.up_proj(x)))
                # Alternative formulation for MLPs with a fused gate_up_proj (kept for reference):
                # up_states = expert.gate_up_proj(x.view(-1, x.size(-1)))  # flatten to [batch_size * seq_length, input_dim]
                # gate, up_states = up_states.chunk(2, dim=-1)
                # up_states = up_states * expert.activation_fn(gate)
                # expert_output = expert.down_proj(up_states)
                # expert_outputs.append(expert_output.view(batch_size, seq_length, -1))

            expert_outputs = torch.stack(expert_outputs, dim=-1)  # Shape: [batch_size, seq_length, hidden_size, num_experts]

            # Use torch.gather to select the expert outputs chosen by the router
            expanded_gate_indices = gate_indices.unsqueeze(-2).expand(-1, -1, self.output_dim, -1)  # Shape: [batch_size, seq_length, hidden_size, k]
            selected_expert_outputs = torch.gather(expert_outputs, -1, expanded_gate_indices)        # Shape: [batch_size, seq_length, hidden_size, k]

            # Weight the selected expert outputs by the gate values
            gate_values = gate_values.unsqueeze(-2)                          # Shape: [batch_size, seq_length, 1, k]
            weighted_expert_outputs = selected_expert_outputs * gate_values  # Shape: [batch_size, seq_length, hidden_size, k]

            # Sum the weighted expert outputs across the k dimension
            moe_output = weighted_expert_outputs.sum(dim=-1)  # Shape: [batch_size, seq_length, hidden_size]

        return moe_output.to(self.gating_layer.gate.weight.dtype)
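
# Worked shape example (illustrative numbers): with batch=1, seq=4, hidden_size=4096, 3 experts and k=1,
# expert_outputs is [1, 4, 4096, 3]; gate_indices [1, 4, 1] is expanded to [1, 4, 4096, 1] so that
# torch.gather keeps one expert output per token, which is then scaled by its re-normalized gate
# weight and summed over the k dimension back to [1, 4, 4096].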

class ModifiedIdefics2DecoderLayer(nn.Module):

    def __init__(self, original_layer, moe_layer):
        super(ModifiedIdefics2DecoderLayer, self).__init__()
        self.self_attn = original_layer.self_attn
        self.mlp = moe_layer
        self.input_layernorm = original_layer.input_layernorm
        self.post_attention_layernorm = original_layer.post_attention_layernorm
        #print ("Init: ModifiedIdefics2DecoderLayer")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            ### MJB TODO: Need to check
            #cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

# Define the Idefics2ForCausalLMMoEConfig
class Idefics2ForCausalLMMoEConfig(Idefics2Config):
    model_type = "idefics2_moe"

    def __init__(self, config=None, k=1, num_expert_models=2, use_embeddings_in_router=False, **kwargs):
        if config is not None:
            kwargs.update(config.to_dict())
        super().__init__(**kwargs)
        self.k = k
        self.num_expert_models = num_expert_models
        self.architectures = ["Idefics2ForCausalLMMoE"]
        self.auto_map = {
            "AutoConfig": "moe_idefics2.Idefics2ForCausalLMMoEConfig",
            "AutoModelForCausalLM": "moe_idefics2.Idefics2ForCausalLMMoE",
        }
        self.use_embeddings_in_router = use_embeddings_in_router
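
# Illustrative example (placeholder names): the MoE config is typically derived from the base
# model's own config, e.g.
#   moe_config = Idefics2ForCausalLMMoEConfig(config=base_model.config, k=1, num_expert_models=3)
# Passing `config=` copies the base Idefics2 fields via to_dict(); `k` and `num_expert_models`
# set the router's top-k and the number of expert MLPs per decoder layer.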

# Define the MoE model
class Idefics2ForCausalLMMoE(Idefics2ForConditionalGeneration):

    config_class = Idefics2ForCausalLMMoEConfig

    def __init__(
        self,
        config,
        base_model=None,
        expert_models=None,
        layer_dtype=torch.bfloat16,
        **kwargs,
    ):
        super().__init__(config)

        self.layer_dtype = layer_dtype
        self.custom_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        k = self.config.k

        self.num_layers = len(base_model.model.text_model.layers) if base_model else 0

        self.config.auto_map = {
            "AutoConfig": "moe_idefics2.Idefics2ForCausalLMMoEConfig",
            "AutoModelForCausalLM": "moe_idefics2.Idefics2ForCausalLMMoE",
        }

        self.use_embeddings_in_router = config.use_embeddings_in_router
        print("Use embeddings in router: ", self.use_embeddings_in_router)

        self.model = base_model or Idefics2ForConditionalGeneration(
            self.config
        )

        if base_model and expert_models:
            # Build the per-layer MoE blocks from the supplied expert models
            self.num_expert_models = len(expert_models)
            self._init_moe_layers(base_model, expert_models, k, layer_dtype)
            print("CONSTRUCTOR self.model", self.model)
        else:
            # No experts supplied (e.g. when the model is instantiated before loading a checkpoint):
            # create placeholder experts so the module structure matches the saved state dict
            print(
                "Init function called and generating dummy experts: k=",
                k,
                "experts=",
                self.config.num_expert_models,
            )
            num_dummy_experts = self.config.num_expert_models
            self._init_moe_layers_with_dummy_experts(
                self.model, k, num_dummy_experts, layer_dtype
            )

        self.config.model_type = "idefics2_moe"

    def _init_base_model(self):
        print("init base model")
        return PreTrainedModel(self.config)

    def _init_moe_layers(self, base_model, expert_models, k, layer_dtype):

        self.num_layers = len(base_model.model.text_model.layers)

        for i in tqdm(range(self.num_layers)):
            # One expert MLP per expert model, copied from that expert's decoder layer
            experts = []
            for expert_model in expert_models:
                expert = copy.deepcopy(expert_model.model.text_model.layers[i].mlp).to(
                    dtype=layer_dtype
                )
                experts.append(expert)

            gating_layer = GatingLayer(
                input_dim=self.config.text_config.hidden_size,
                num_experts=len(experts),
                k=k,
                layer_dtype=layer_dtype,
            )
            moe_layer = MoE(
                input_dim=self.config.text_config.hidden_size,
                experts=experts,
                gating_layer=gating_layer,
                config=self.config,
            ).to(dtype=layer_dtype)

            self.model.model.text_model.layers[i] = ModifiedIdefics2DecoderLayer(
                self.model.model.text_model.layers[i], moe_layer
            ).to(dtype=layer_dtype)

    def _init_moe_layers_with_dummy_experts(
        self, base_model, k, num_dummy_experts, layer_dtype
    ):
        self.num_layers = len(base_model.model.text_model.layers)

        for i in tqdm(range(self.num_layers)):
            experts = []
            for _ in range(num_dummy_experts):
                dummy_expert = Idefics2MLP(
                    hidden_size=self.config.text_config.hidden_size,
                    intermediate_size=self.config.text_config.intermediate_size,
                    output_size=self.config.text_config.hidden_size,
                    hidden_act=self.config.perceiver_config.hidden_act,
                ).to(dtype=layer_dtype)

                experts.append(dummy_expert)

            gating_layer = GatingLayer(
                input_dim=self.config.text_config.hidden_size,
                num_experts=len(experts),
                k=k,
                layer_dtype=layer_dtype,
            )
            moe_layer = MoE(
                input_dim=self.config.text_config.hidden_size,
                experts=experts,
                gating_layer=gating_layer,
                config=self.config,
            ).to(dtype=layer_dtype)

            self.model.model.text_model.layers[i] = ModifiedIdefics2DecoderLayer(
                self.model.model.text_model.layers[i], moe_layer
            ).to(dtype=layer_dtype)

    def get_input_embeddings(self):
        if hasattr(self.model, "text_model"):
            return self.model.text_model.get_input_embeddings()
        return self.model.model.text_model.get_input_embeddings()

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        model = super(Idefics2ForCausalLMMoE, cls).from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        return model

    def plot_loss_histories(self, all_loss_histories, loss_steps, filename="loss_history.svg"):
        plt.figure(figsize=(12, 8))
        for layer_idx, loss_history in enumerate(all_loss_histories):
            plt.plot(
                range(0, len(loss_history) * loss_steps, loss_steps),
                loss_history,
                label=f'Layer {layer_idx}',
                linewidth=2,  # thicker line
                marker='o'    # circle marker for each data point
            )
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Loss History per Layer, MoE Gating Network')
        plt.legend()
        plt.grid(True)
        try:
            plt.savefig(filename)
        except Exception:
            print("Figure file save failed...")
        plt.show()
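
    # Illustrative layout of `prompts_per_expert` as consumed below (placeholder prompts/images):
    # one list of calibration prompts per expert, each prompt a dict with 'text' (typically
    # containing the processor's "<image>" token) and 'image', e.g.
    #   prompts_per_expert = [
    #       [{'text': 'User:<image>Describe the microstructure.', 'image': img_a}, ...],  # expert 0
    #       [{'text': 'User:<image>What organism is shown?', 'image': img_b}, ...],       # expert 1
    #   ]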

    def train_gating_layer_params_from_hidden_states(self, processor, prompts_per_expert, epochs=1000, loss_steps=100,
                                                     lr=1e-4, layer_offset=0):
        self.to(self.custom_device)
        self.eval()

        print('dtype:', self.layer_dtype, 'device=', self.custom_device)

        all_gating_layer_params = []
        all_loss_histories = []  # to store loss histories for each layer

        expert_hidden_states_per_layer = [[] for _ in range(self.num_layers)]

        # Collect hidden states for each expert
        for prompts in tqdm(prompts_per_expert, desc="Processing Prompts"):
            for prompt in tqdm(prompts, desc="Processing Single Prompt", leave=False):
                inputs = processor(text=prompt['text'], images=prompt['image'], return_tensors="pt").to(self.custom_device).to(self.layer_dtype)
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                hidden_states = outputs.hidden_states
                for layer_idx in tqdm(range(self.num_layers)):
                    hidden_state = hidden_states[layer_idx + layer_offset].mean(dim=1)  # average over the sequence dimension

                    expert_hidden_states_per_layer[layer_idx].append(hidden_state)

        # Train the gating layers
        for layer_idx in tqdm(range(self.num_layers), desc="Training Gating Layers"):
            print(f"Training gating layer parameters for layer {layer_idx}")

            # Ensure we have hidden states collected for the current layer
            if not expert_hidden_states_per_layer[layer_idx]:
                raise ValueError(f"No hidden states collected for layer {layer_idx}")

            # Aggregate hidden states for each expert and stack them
            expert_hidden_states = []
            num_prompts_per_expert = len(prompts_per_expert[0])
            for i in range(len(prompts_per_expert)):
                hidden_states_for_expert = expert_hidden_states_per_layer[layer_idx][i * num_prompts_per_expert: (i + 1) * num_prompts_per_expert]
                hidden_state_avg = torch.stack(hidden_states_for_expert).mean(dim=0)
                expert_hidden_states.append(hidden_state_avg)
            expert_hidden_states = torch.stack(expert_hidden_states).to(self.layer_dtype)

            input_dim = self.config.text_config.hidden_size
            num_experts = self.config.num_expert_models

            # Lightweight router used only for this offline training step; its weights are
            # loaded into the per-layer GatingLayer afterwards (see set_gating_layer_params)
            class SimpleGatingLayer(nn.Module):
                def __init__(self, input_dim, num_experts, layer_dtype=torch.bfloat16):
                    super(SimpleGatingLayer, self).__init__()
                    self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)

                def forward(self, x):
                    # return torch.softmax(self.gate(x), dim=-1)
                    return self.gate(x)

            gating_layer = SimpleGatingLayer(self.config.text_config.hidden_size,
                                             num_experts, layer_dtype=self.layer_dtype).to(self.custom_device)

            criterion = nn.CrossEntropyLoss()
            optimizer = Adam(gating_layer.parameters(), lr=lr)

            loss_history = []

            for epoch in tqdm(range(epochs), desc=f"Training Gating Layer {layer_idx}"):
                optimizer.zero_grad()
                # Reshape expert_hidden_states to match (batch_size, input_dim)
                expert_hidden_states_reshaped = expert_hidden_states.view(-1, input_dim)
                outputs = gating_layer(expert_hidden_states_reshaped)
                labels = torch.arange(num_experts).to(self.custom_device)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                if epoch % loss_steps == 0:
                    loss_history.append(loss.item())

            all_loss_histories.append(loss_history)
            all_gating_layer_params.append(gating_layer.state_dict())

        self.plot_loss_histories(all_loss_histories, loss_steps)
        return all_gating_layer_params

    def set_gating_layer_params(self, gating_layer_params):
        for layer_idx, params in enumerate(gating_layer_params):
            self.model.model.text_model.layers[layer_idx].mlp.gating_layer.load_state_dict(params)

    def freeze_except_gating_layers(model):

        # Freeze all parameters
        for param in model.parameters():
            param.requires_grad = False

        # Unfreeze gating layer parameters
        for layer in model.model.model.text_model.layers:
            for name, param in layer.mlp.gating_layer.named_parameters():
                param.requires_grad = True

    def un_freeze_all(model):
        # Unfreeze all parameters
        for param in model.parameters():
            param.requires_grad = True
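
    # Note: `freeze_except_gating_layers` and `un_freeze_all` are defined with the instance bound to
    # the `model` parameter, i.e. they are intended to be called as
    # moe_model.freeze_except_gating_layers() and moe_model.un_freeze_all().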

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

AutoConfig.register("idefics2_moe", Idefics2ForCausalLMMoEConfig)

# Register the model class with the Auto* factories so the "idefics2_moe" model type resolves to
# Idefics2ForCausalLMMoE
AutoModel.register(Idefics2ForCausalLMMoEConfig, Idefics2ForCausalLMMoE)
AutoModelForCausalLM.register(Idefics2ForCausalLMMoEConfig, Idefics2ForCausalLMMoE)
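
Minimal usage sketch for the classes defined above (checkpoint names, prompts, and images are
illustrative placeholders; hyperparameters are examples only):

import torch
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration
from moe_idefics2 import Idefics2ForCausalLMMoEConfig, Idefics2ForCausalLMMoE

base_id = "HuggingFaceM4/idefics2-8b"                 # placeholder: base checkpoint
expert_ids = ["<expert-ckpt-0>", "<expert-ckpt-1>"]   # placeholders: fine-tuned expert checkpoints

processor = AutoProcessor.from_pretrained(base_id)
base_model = Idefics2ForConditionalGeneration.from_pretrained(base_id, torch_dtype=torch.bfloat16)
expert_models = [Idefics2ForConditionalGeneration.from_pretrained(e, torch_dtype=torch.bfloat16)
                 for e in expert_ids]

# One GatingLayer + MoE block per text decoder layer; expert MLPs come from the expert checkpoints
moe_config = Idefics2ForCausalLMMoEConfig(config=base_model.config, k=1,
                                          num_expert_models=len(expert_models))
moe_model = Idefics2ForCausalLMMoE(moe_config, base_model=base_model,
                                   expert_models=expert_models, layer_dtype=torch.bfloat16)

# Expert-specific calibration prompts (placeholder text/images) used to train the per-layer routers
img = Image.new("RGB", (224, 224))
prompts_per_expert = [
    [{"text": "User:<image>Describe the microstructure.<end_of_utterance>\nAssistant:", "image": img}],
    [{"text": "User:<image>What organism is shown?<end_of_utterance>\nAssistant:", "image": img}],
]

gating_params = moe_model.train_gating_layer_params_from_hidden_states(
    processor, prompts_per_expert, epochs=1000, loss_steps=100, lr=1e-4)
moe_model.set_gating_layer_params(gating_params)

# Optionally continue training only the routers afterwards
moe_model.freeze_except_gating_layers()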