appledora committed (verified)
Commit 1a17fc1 · 1 Parent(s): 46b6da4

Upload 4 files

recast_llama/__init__.py ADDED
@@ -0,0 +1,31 @@
+ from transformers.utils import (
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_torch_available,
+ )
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     pass
+ else:
+     from .modeling_recast_llama import (
+         RECAST_llamaModel,
+         RECAST_LlamaForCausalLM,
+     )
+     from .configuration_recast_llama import RECAST_llama
+
+     from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+     # Register the custom classes with the Auto factories
+     AutoConfig.register("recast_llama", RECAST_llama)
+     AutoModel.register(RECAST_llama, RECAST_llamaModel)
+     AutoModelForCausalLM.register(RECAST_llama, RECAST_LlamaForCausalLM)
+
+ # Keys must match the actual module names shipped in this package
+ _import_structure = {
+     "configuration_recast_llama": ["RECAST_llama"],
+     "modeling_recast_llama": ["RECAST_llamaModel", "RECAST_LlamaForCausalLM"],
+ }
+
+ __all__ = ["RECAST_llamaModel", "RECAST_LlamaForCausalLM", "RECAST_llama"]
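
For orientation, a minimal usage sketch of the registration performed in `__init__.py` above; the repository id used here is a placeholder assumption, not a real checkpoint:

import recast_llama  # importing the package runs the AutoConfig/AutoModel registration above
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; any checkpoint whose config.json has model_type "recast_llama"
# resolves to RECAST_LlamaForCausalLM through the registration performed on import.
config = AutoConfig.from_pretrained("your-namespace/your-recast-llama-checkpoint")
model = AutoModelForCausalLM.from_pretrained(
    "your-namespace/your-recast-llama-checkpoint", config=config
)
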
recast_llama/configuration_recast_llama.py ADDED
@@ -0,0 +1,77 @@
+ from transformers import PretrainedConfig
+
+
+ class RECAST_llama(PretrainedConfig):
+     model_type = "recast_llama"
+     attribute_map = {
+         "hidden_size": "hidden_size",
+         "num_attention_heads": "num_attention_heads",
+     }
+
+     def __init__(
+         self,
+         vocab_size=128256,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         hidden_act="silu",
+         max_position_embeddings=131072,
+         initializer_range=0.02,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=128000,
+         eos_token_id=128001,
+         pretraining_tp=1,
+         tie_word_embeddings=False,
+         rope_theta=500000.0,
+         rope_scaling={
+             "factor": 8.0,
+             "low_freq_factor": 1.0,
+             "high_freq_factor": 4.0,
+             "original_max_position_embeddings": 8192,
+             "rope_type": "llama3",
+         },
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         # Template-specific configs
+         num_templates=4,
+         num_groups=8,
+         num_cf=1,
+         torch_dtype="bfloat16",
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.pretraining_tp = pretraining_tp
+         self.use_cache = use_cache
+         self.mlp_bias = mlp_bias
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.torch_dtype = torch_dtype
+
+         # Template-specific configs
+         self.num_templates = num_templates
+         self.num_groups = num_groups
+         self.num_cf = num_cf
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs
+         )
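
A small configuration sketch for reference; the sizes below are illustrative assumptions, deliberately much smaller than the Llama-3.1-8B-scale defaults above:

from recast_llama.configuration_recast_llama import RECAST_llama

# Tiny, illustrative configuration; num_hidden_layers should be divisible by num_groups
# so that each group of layers can share one MLP template bank.
config = RECAST_llama(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=4,
    num_templates=4,
    num_groups=4,
)
config.save_pretrained("./recast_llama_tiny")  # writes config.json with model_type "recast_llama"
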
recast_llama/modeling_recast_llama.py ADDED
@@ -0,0 +1,674 @@
+ # filename: modeling_recast_llama.py
+ from .configuration_recast_llama import RECAST_llama
+ from transformers import PreTrainedModel
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple, Union, List
+ from transformers import AutoConfig
+ from transformers.utils import logging
+ from transformers.cache_utils import Cache, StaticCache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ logger = logging.get_logger(__name__)
+
+
+ class MLPTemplateBank(nn.Module):
+     def __init__(self, config, num_templates):
+         super().__init__()
+         self.num_templates = num_templates
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Store templates in a more efficient layout
+         self.up_templates = nn.Parameter(
+             torch.empty(num_templates, self.intermediate_size * self.hidden_size)
+         )
+         self.gate_templates = nn.Parameter(
+             torch.empty(num_templates, self.intermediate_size * self.hidden_size)
+         )
+         self.down_templates = nn.Parameter(
+             torch.empty(num_templates, self.hidden_size * self.intermediate_size)
+         )
+
+         nn.init.kaiming_normal_(self.up_templates)
+         nn.init.kaiming_normal_(self.gate_templates)
+         nn.init.kaiming_normal_(self.down_templates)
+
+     def forward(self, up_coeffs, gate_coeffs, down_coeffs):
+         # Simple matrix multiplication instead of broadcasting
+         up_weights = torch.mm(up_coeffs, self.up_templates)
+         gate_weights = torch.mm(gate_coeffs, self.gate_templates)
+         down_weights = torch.mm(down_coeffs, self.down_templates)
+         up_weights = up_weights.view(self.intermediate_size, self.hidden_size)
+         gate_weights = gate_weights.view(self.intermediate_size, self.hidden_size)
+         down_weights = down_weights.view(self.hidden_size, self.intermediate_size)
+         return gate_weights, up_weights, down_weights
+
+
+ class SharedLlamaMLP(nn.Module):
+     def __init__(self, config, bank):
+         super().__init__()
+         self.config = config
+         self.bank = bank
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Use transposed coefficients to avoid unnecessary operations
+         # self.coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
+         # nn.init.normal_(self.coefficients, std=0.02)
+         self.up_coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
+         self.gate_coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
+         self.down_coefficients = nn.Parameter(torch.zeros(1, config.num_templates))
+         nn.init.orthogonal_(self.up_coefficients)
+         nn.init.orthogonal_(self.gate_coefficients)
+         nn.init.orthogonal_(self.down_coefficients)
+         if config.mlp_bias:
+             self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
+             self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
+             self.down_bias = nn.Parameter(torch.zeros(self.hidden_size))
+         else:
+             self.register_parameter("gate_bias", None)
+             self.register_parameter("up_bias", None)
+             self.register_parameter("down_bias", None)
+
+         self.act_fn = F.silu
+
+     def forward(self, x):
+         # Generate weights with minimal operations
+         gate_weights, up_weights, down_weights = self.bank(
+             self.up_coefficients, self.gate_coefficients, self.down_coefficients
+         )
+
+         # Standard MLP operations
+         gate_output = F.linear(x, gate_weights, self.gate_bias)
+         up_output = F.linear(x, up_weights, self.up_bias)
+
+         hidden_states = self.act_fn(gate_output) * up_output
+         output = F.linear(hidden_states, down_weights, self.down_bias)
+
+         return output
+
+
+ def fixed_cross_entropy(
+     source,
+     target,
+     num_items_in_batch: int = None,
+     ignore_index: int = -100,
+     **kwargs,
+ ):
+     reduction = "sum" if num_items_in_batch is not None else "mean"
+     loss = nn.functional.cross_entropy(
+         source, target, ignore_index=ignore_index, reduction=reduction
+     )
+     if reduction == "sum":
+         loss = loss / num_items_in_batch
+     return loss
+
+
+ from transformers.models.llama.modeling_llama import (
+     LlamaDecoderLayer,
+     LlamaRotaryEmbedding,
+     LlamaRMSNorm,
+     apply_rotary_pos_emb,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+
+
+ class RECAST_llamaModel(PreTrainedModel):
+     config_class = RECAST_llama
+     base_model_prefix = "llama"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size, config.hidden_size, self.padding_idx
+         )
+         # Initialize rotary embeddings
+         rope_config = config.rope_scaling
+         if rope_config:
+             rope_type = rope_config.get("rope_type", "default")
+             scaling_factor = rope_config.get("factor", 1.0)
+         else:
+             rope_type = "default"
+             scaling_factor = None
+         original_config = AutoConfig.from_pretrained(
+             "meta-llama/Llama-3.1-8b", trust_remote_code=True
+         )
+         self.rotary_emb = LlamaRotaryEmbedding(
+             config=original_config,
+         )
+
+         # Create template banks first
+         self.banks = []
+         layers_per_group = config.num_hidden_layers // config.num_groups
+         for _ in range(config.num_groups):
+             bank = MLPTemplateBank(config, config.num_templates)
+             self.banks.append(bank)
+
+         # Create layers using LlamaDecoderLayer but replace MLPs
+         self.layers = nn.ModuleList()
+         for layer_idx in range(config.num_hidden_layers):
+             # Create standard LlamaDecoderLayer
+             decoder_layer = LlamaDecoderLayer(config, layer_idx)
+
+             # Replace its MLP with our SharedLlamaMLP
+             group_idx = layer_idx // layers_per_group
+             group_bank = self.banks[group_idx]
+             decoder_layer.mlp = SharedLlamaMLP(config, bank=group_bank)
+
+             self.layers.append(decoder_layer)
+
+         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **flash_attn_kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You must specify exactly one of input_ids or inputs_embeds"
+             )
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+         # Set up cache position if not provided
+         if cache_position is None:
+             past_seen_tokens = (
+                 0
+                 if past_key_values is None
+                 else (
+                     past_key_values.get_seq_length()
+                     if isinstance(past_key_values, Cache)
+                     else past_key_values[0][0].size(-2) if past_key_values else 0
+                 )
+             )
+             cache_position = torch.arange(
+                 past_seen_tokens,
+                 past_seen_tokens + inputs_embeds.shape[1],
+                 device=inputs_embeds.device,
+             )
+         # Create position embeddings to be shared across the decoder layers
+         # Set up position IDs if not provided
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+         # Get updated causal mask
+         causal_mask = self._update_causal_mask(
+             attention_mask,
+             inputs_embeds,
+             cache_position,
+             past_key_values,
+             output_attentions,
+         )
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # Initialize outputs
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         # Process through layers
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     position_embeddings=position_embeddings,
+                     **flash_attn_kwargs,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         # Final layer norm
+         hidden_states = self.norm(hidden_states)
+
+         # Add last hidden state
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         if isinstance(
+             pretrained_model_name_or_path, str
+         ) and pretrained_model_name_or_path.endswith(".pt"):
+             print("Loading from local checkpoint")
+             # Load from local checkpoint
+             config = kwargs.get("config", None)
+             if config is None:
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=True
+                 )
+
+             model = cls(config)
+             checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
+             state_dict = checkpoint["model_state_dict"]
+             logger.info(
+                 f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
+             )
+
+             missing_keys, unexpected_keys = model.load_state_dict(
+                 state_dict, strict=False
+             )
+
+             if len(missing_keys) > 0:
+                 logger.warning(f"Missing keys: {missing_keys}")
+             if len(unexpected_keys) > 0:
+                 logger.warning(f"Unexpected keys: {unexpected_keys}")
+
+             return model
+         else:
+             print("Loading from hub")
+             # Load from hub using parent's from_pretrained
+             return super().from_pretrained(
+                 pretrained_model_name_or_path, *model_args, **kwargs
+             )
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = (
+             past_key_values.get_seq_length() if past_key_values is not None else 0
+         )
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if (
+             self.config._attn_implementation == "sdpa"
+             and not using_static_cache
+             and not output_attentions
+         ):
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(
+                 causal_mask, min_dtype
+             )
+
+         return causal_mask
+
+     @staticmethod
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length),
+                 fill_value=min_dtype,
+                 dtype=dtype,
+                 device=device,
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(
+                 target_length, device=device
+             ) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = (
+                     causal_mask.clone()
+                 )  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = (
+                     causal_mask[:, :, :, :mask_length]
+                     + attention_mask[:, None, None, :]
+                 )
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[
+                     :, :, :, :mask_length
+                 ].masked_fill(padding_mask, min_dtype)
+
+         return causal_mask
+
+
+ class RECAST_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     config_class = RECAST_llama
+     base_model_prefix = "llama"
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = RECAST_llamaModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def loss_function(
+         self,
+         logits,
+         labels,
+         vocab_size: int,
+         num_items_in_batch: int = None,
+         ignore_index: int = -100,
+         **kwargs,
+     ):
+         # Upcast to float if we need to compute the loss to avoid potential precision issues
+         logits = logits.float()
+         # Shift so that tokens < n predict n
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         # Flatten the tokens
+         shift_logits = shift_logits.view(-1, vocab_size)
+         shift_labels = shift_labels.view(-1)
+         # Enable model parallelism
+         shift_labels = shift_labels.to(shift_logits.device)
+         loss = fixed_cross_entropy(
+             shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
+         )
+         return loss
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+         **kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         """
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should be in
+                 `[0, ..., config.vocab_size]` or -100 (masked tokens).
+             num_logits_to_keep (`int`, *optional*):
+                 Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+         # Only compute necessary logits
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+         loss = None
+         if labels is not None:
+             # Calculate batch size for loss function
+             num_items_in_batch = (
+                 input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
+             )
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels,
+                 vocab_size=self.config.vocab_size,
+                 num_items_in_batch=num_items_in_batch,
+                 **kwargs,
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         **kwargs,
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         if isinstance(
+             pretrained_model_name_or_path, str
+         ) and pretrained_model_name_or_path.endswith(".pt"):
+             print("Loading from local checkpoint")
+             config = kwargs.get("config", None)
+             if config is None:
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=True
+                 )
+
+             model = cls(config)
+             checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
+             state_dict = checkpoint["model_state_dict"]
+
+             missing_keys, unexpected_keys = model.load_state_dict(
+                 state_dict, strict=False
+             )
+
+             if len(missing_keys) > 0:
+                 logger.warning(f"Missing keys: {missing_keys}")
+             if len(unexpected_keys) > 0:
+                 logger.warning(f"Unexpected keys: {unexpected_keys}")
+
+             return model
+         else:
+             print("Loading from hub")
+             return super().from_pretrained(
+                 pretrained_model_name_or_path, *model_args, **kwargs
+             )
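
Finally, a hedged end-to-end sketch of how the classes above could be exercised. The checkpoint path and the tokenizer id are assumptions (neither is part of this upload), and building the default config triggers the same gated `meta-llama/Llama-3.1-8b` config fetch that `RECAST_llamaModel.__init__` performs:

import torch
from transformers import AutoTokenizer
from recast_llama.configuration_recast_llama import RECAST_llama
from recast_llama.modeling_recast_llama import RECAST_LlamaForCausalLM

# "checkpoints/recast_llama.pt" is a placeholder; the .pt branch of from_pretrained above
# expects a torch.save'd dict containing a "model_state_dict" entry.
config = RECAST_llama()
model = RECAST_LlamaForCausalLM.from_pretrained("checkpoints/recast_llama.pt", config=config)
model.eval()

# The tokenizer is not defined in this repo; reusing the base Llama tokenizer is an assumption.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
inputs = tokenizer("The RECAST template bank shares MLP weights", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(generated[0], skip_special_tokens=True))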