QscQ committed
Commit 5fe30d5 · 1 Parent(s): b1645e4
Files changed (2)
  1. configuration_minimax_m1.py +14 -14
  2. modeling_minimax_m1.py +57 -57
configuration_minimax_m1.py CHANGED
@@ -1,4 +1,4 @@
- """ MiniMaxText01 model configuration"""
+ """ MiniMaxM1 model configuration"""

  from transformers.configuration_utils import PretrainedConfig
  from transformers.utils import logging
@@ -7,11 +7,11 @@ from transformers.utils import logging
  logger = logging.get_logger(__name__)


- class MiniMaxText01Config(PretrainedConfig):
+ class MiniMaxM1Config(PretrainedConfig):
  r"""
- This is the configuration class to store the configuration of a [`MiniMaxText01Model`]. It is used to instantiate an
- MiniMaxText01 model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the MiniMaxText01.
+ This is the configuration class to store the configuration of a [`MiniMaxM1Model`]. It is used to instantiate an
+ MiniMaxM1 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the MiniMaxM1.

  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  documentation from [`PretrainedConfig`] for more information.
@@ -19,8 +19,8 @@ class MiniMaxText01Config(PretrainedConfig):

  Args:
  vocab_size (`int`, *optional*, defaults to 32000):
- Vocabulary size of the MiniMaxText01 model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`MiniMaxText01Model`]
+ Vocabulary size of the MiniMaxM1 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MiniMaxM1Model`]
  hidden_size (`int`, *optional*, defaults to 4096):
  Dimension of the hidden representations.
  intermediate_size (`int`, *optional*, defaults to 14336):
@@ -39,7 +39,7 @@ class MiniMaxText01Config(PretrainedConfig):
  hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  The non-linear activation function (function or string) in the decoder.
  max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
- The maximum sequence length that this model might ever be used with. MiniMaxText01's sliding window attention
+ The maximum sequence length that this model might ever be used with. MiniMaxM1's sliding window attention
  allows sequence of up to 4096*32 tokens.
  initializer_range (`float`, *optional*, defaults to 0.02):
  The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@@ -76,19 +76,19 @@ class MiniMaxText01Config(PretrainedConfig):
  Amount of noise to add to the router.

  ```python
- >>> from transformers import MiniMaxText01Model, MiniMaxText01Config
+ >>> from transformers import MiniMaxM1Model, MiniMaxM1Config

- >>> # Initializing a MiniMaxText01 style configuration
- >>> configuration = MiniMaxText01Config()
+ >>> # Initializing a MiniMaxM1 style configuration
+ >>> configuration = MiniMaxM1Config()

- >>> # Initializing a model from the MiniMaxText01 style configuration
- >>> model = MiniMaxText01Model(configuration)
+ >>> # Initializing a model from the MiniMaxM1 style configuration
+ >>> model = MiniMaxM1Model(configuration)

  >>> # Accessing the model configuration
  >>> configuration = model.config
  ```"""

- model_type = "MiniMaxText01"
+ model_type = "MiniMaxM1"
  keys_to_ignore_at_inference = ["past_key_values"]

  def __init__(
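In short, this file renames `MiniMaxText01Config` to `MiniMaxM1Config` and changes `model_type` to `"MiniMaxM1"`. A minimal sketch of how the renamed configuration is picked up through the auto classes; the Hub repo id and the use of `trust_remote_code` below are assumptions for illustration, not part of this commit:

```python
# Minimal sketch: load the renamed configuration via the remote-code path.
# The repo id and trust_remote_code usage are assumptions for illustration only.
from transformers import AutoConfig

repo_id = "MiniMaxAI/MiniMax-M1-80k"  # hypothetical Hub repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# After this commit the remote-code class is MiniMaxM1Config (model_type "MiniMaxM1")
# rather than MiniMaxText01Config.
print(type(config).__name__, config.model_type)
```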
modeling_minimax_m1.py CHANGED
@@ -1,4 +1,4 @@
- """ PyTorch MiniMaxText01 model."""
+ """ PyTorch MiniMaxM1 model."""
  import inspect
  import math
  import warnings
@@ -31,7 +31,7 @@ from transformers.utils import (
  replace_return_docstrings,
  )
  from transformers.utils.import_utils import is_torch_fx_available
- from .configuration_minimax_m1 import MiniMaxText01Config
+ from .configuration_minimax_m1 import MiniMaxM1Config

  if is_flash_attn_2_available():
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -52,7 +52,7 @@ BLOCK = 256

  logger = logging.get_logger(__name__)

- _CONFIG_FOR_DOC = "MiniMaxText01Config"
+ _CONFIG_FOR_DOC = "MiniMaxM1Config"


  def get_activation_fn(activation):
@@ -207,8 +207,8 @@ class GLU(nn.Module):
  return output


- class MiniMaxText01LightningAttention(nn.Module):
- def __init__(self, config: MiniMaxText01Config, layer_idx: Optional[int] = None):
+ class MiniMaxM1LightningAttention(nn.Module):
+ def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
  super().__init__()
  bias = False
  self.hidden_size = config.hidden_size
@@ -217,7 +217,7 @@ class MiniMaxText01LightningAttention(nn.Module):

  self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
  self.act = get_activation_fn(config.hidden_act)
- self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
+ self.norm = MiniMaxM1RMSNorm(self.head_dim * self.num_heads)

  self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
  self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
@@ -338,11 +338,11 @@ class MiniMaxText01LightningAttention(nn.Module):
  return output, attn_weights, kv


- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
- class MiniMaxText01RMSNorm(nn.Module):
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxM1
+ class MiniMaxM1RMSNorm(nn.Module):
  def __init__(self, hidden_size, eps=1e-6):
  """
- MiniMaxText01RMSNorm is equivalent to T5LayerNorm
+ MiniMaxM1RMSNorm is equivalent to T5LayerNorm
  """
  super().__init__()
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -356,8 +356,8 @@ class MiniMaxText01RMSNorm(nn.Module):
  return self.weight * hidden_states.to(input_dtype)


- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxText01
- class MiniMaxText01RotaryEmbedding(nn.Module):
+ # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxM1
+ class MiniMaxM1RotaryEmbedding(nn.Module):
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  super().__init__()

@@ -447,14 +447,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


- # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxText01
- class MiniMaxText01Attention(nn.Module):
+ # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxM1
+ class MiniMaxM1Attention(nn.Module):
  """
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  and "Generating Long Sequences with Sparse Transformers".
  """

- def __init__(self, config: MiniMaxText01Config, layer_idx: Optional[int] = None):
+ def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
  super().__init__()
  self.config = config
  self.layer_idx = layer_idx
@@ -481,7 +481,7 @@ class MiniMaxText01Attention(nn.Module):
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)

- self.rotary_emb = MiniMaxText01RotaryEmbedding(
+ self.rotary_emb = MiniMaxM1RotaryEmbedding(
  self.rotary_dim,
  max_position_embeddings=self.max_position_embeddings,
  base=self.rope_theta,
@@ -572,10 +572,10 @@ class MiniMaxText01Attention(nn.Module):
  return attn_output, attn_weights, past_key_value


- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxText01
- class MiniMaxText01FlashAttention2(MiniMaxText01Attention):
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxM1
+ class MiniMaxM1FlashAttention2(MiniMaxM1Attention):
  """
- MiniMaxText01 flash attention module. This module inherits from `MiniMaxText01Attention` as the weights of the module stays
+ MiniMaxM1 flash attention module. This module inherits from `MiniMaxM1Attention` as the weights of the module stays
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  flash attention and deal with padding tokens in case the input contains any of them.
  """
@@ -836,7 +836,7 @@ class MiniMaxText01FlashAttention2(MiniMaxText01Attention):
  )


- class MiniMaxText01MLP(nn.Module):
+ class MiniMaxM1MLP(nn.Module):
  def __init__(self, config):
  super().__init__()
  self.config = config
@@ -852,8 +852,8 @@ class MiniMaxText01MLP(nn.Module):
  return down_proj


- class MiniMaxText01BlockSparseTop2MLP(nn.Module):
- def __init__(self, config: MiniMaxText01Config):
+ class MiniMaxM1BlockSparseTop2MLP(nn.Module):
+ def __init__(self, config: MiniMaxM1Config):
  super().__init__()
  self.ffn_dim = config.intermediate_size
  self.hidden_dim = config.hidden_size
@@ -870,15 +870,15 @@ class MiniMaxText01BlockSparseTop2MLP(nn.Module):
  return current_hidden_states


- class MiniMaxText01BLockSparseTop2MLP(MiniMaxText01BlockSparseTop2MLP):
+ class MiniMaxM1BLockSparseTop2MLP(MiniMaxM1BlockSparseTop2MLP):
  def __init__(self, *args, **kwargs):
  logger.warning_once(
- "MiniMaxText01BLockSparseTop2MLP is deprecated by MiniMaxText01BlockSparseTop2MLP and will be removed in v4.40."
+ "MiniMaxM1BLockSparseTop2MLP is deprecated by MiniMaxM1BlockSparseTop2MLP and will be removed in v4.40."
  )
  super().__init__(*args, **kwargs)


- class MiniMaxText01SparseMoeBlock(nn.Module):
+ class MiniMaxM1SparseMoeBlock(nn.Module):
  """
  This implementation is
  strictly equivalent to standard MoE with full capacity (no
@@ -900,7 +900,7 @@ class MiniMaxText01SparseMoeBlock(nn.Module):
  # gating
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

- self.experts = nn.ModuleList([MiniMaxText01BlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+ self.experts = nn.ModuleList([MiniMaxM1BlockSparseTop2MLP(config) for _ in range(self.num_experts)])

  # Jitter parameters
  self.jitter_noise = config.router_jitter_noise
@@ -946,8 +946,8 @@ class MiniMaxText01SparseMoeBlock(nn.Module):
  return final_hidden_states, router_logits


- class MiniMaxText01DecoderLayer(nn.Module):
- def __init__(self, config: MiniMaxText01Config, layer_idx: int):
+ class MiniMaxM1DecoderLayer(nn.Module):
+ def __init__(self, config: MiniMaxM1Config, layer_idx: int):
  super().__init__()
  self.config = config
  self.hidden_size = config.hidden_size
@@ -956,9 +956,9 @@ class MiniMaxText01DecoderLayer(nn.Module):

  self.layer_idx = layer_idx

- self.block_sparse_moe = MiniMaxText01SparseMoeBlock(config)
- self.input_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.block_sparse_moe = MiniMaxM1SparseMoeBlock(config)
+ self.input_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

  self.postnorm = getattr(config, 'postnorm', False)
  self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
@@ -972,14 +972,14 @@ class MiniMaxText01DecoderLayer(nn.Module):
  self.shared_moe = False
  if shared_intermediate > 0:
  self.shared_moe = True
- self.shared_mlp = MiniMaxText01MLP(config)
+ self.shared_mlp = MiniMaxM1MLP(config)
  self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)

  def build_attn(self, config, layer_idx):
  if config.attention_type == 0:
- Attention_module = MiniMaxText01LightningAttention
+ Attention_module = MiniMaxM1LightningAttention
  else:
- Attention_module = MiniMaxText01FlashAttention2
+ Attention_module = MiniMaxM1FlashAttention2

  return Attention_module(
  config,
@@ -1081,7 +1081,7 @@ MIXTRAL_START_DOCSTRING = r"""
  and behavior.

  Parameters:
- config ([`MiniMaxText01Config`]):
+ config ([`MiniMaxM1Config`]):
  Model configuration class with all the parameters of the model. Initializing with a config file does not
  load the weights associated with the model, only the configuration. Check out the
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1089,15 +1089,15 @@ MIXTRAL_START_DOCSTRING = r"""


  @add_start_docstrings(
- "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
+ "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
  MIXTRAL_START_DOCSTRING,
  )
- # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxText01
- class MiniMaxText01PreTrainedModel(PreTrainedModel):
- config_class = MiniMaxText01Config
+ # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxM1
+ class MiniMaxM1PreTrainedModel(PreTrainedModel):
+ config_class = MiniMaxM1Config
  base_model_prefix = "model"
  supports_gradient_checkpointing = True
- _no_split_modules = ["MiniMaxText01DecoderLayer"]
+ _no_split_modules = ["MiniMaxM1DecoderLayer"]
  _skip_keys_device_placement = "past_key_values"
  _supports_flash_attn_2 = True
  _supports_sdpa = True
@@ -1182,19 +1182,19 @@ MIXTRAL_INPUTS_DOCSTRING = r"""


  @add_start_docstrings(
- "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.",
+ "The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
  MIXTRAL_START_DOCSTRING,
  )
- # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxText01
- class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
+ # Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxM1
+ class MiniMaxM1Model(MiniMaxM1PreTrainedModel):
  """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxText01DecoderLayer`]
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxM1DecoderLayer`]

  Args:
- config: MiniMaxText01Config
+ config: MiniMaxM1Config
  """

- def __init__(self, config: MiniMaxText01Config):
+ def __init__(self, config: MiniMaxM1Config):
  super().__init__(config)
  self.padding_idx = config.pad_token_id
  self.vocab_size = config.vocab_size
@@ -1212,10 +1212,10 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
  else:
  _config._attn_implementation = config_copy._attn_implementation
  _config.attention_type = 1
- self.layers.append(MiniMaxText01DecoderLayer(_config, i))
+ self.layers.append(MiniMaxM1DecoderLayer(_config, i))

  self._attn_implementation = config_copy._attn_implementation
- self.norm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.norm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

  self.gradient_checkpointing = False
  self.slopes = self._build_slope_tensor(config.num_attention_heads)
@@ -1327,7 +1327,7 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
  if is_padding_right:
  raise ValueError(
  "You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of MiniMaxText01. Make sure to "
+ " this may lead to unexpected behaviour for Flash Attention version of MiniMaxM1. Make sure to "
  " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
  )
  slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
@@ -1401,12 +1401,12 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
  )


- class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
+ class MiniMaxM1ForCausalLM(MiniMaxM1PreTrainedModel):
  _tied_weights_keys = ["lm_head.weight"]

  def __init__(self, config):
  super().__init__(config)
- self.model = MiniMaxText01Model(config)
+ self.model = MiniMaxM1Model(config)
  self.vocab_size = config.vocab_size
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  self.router_aux_loss_coef = config.router_aux_loss_coef
@@ -1462,9 +1462,9 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
  Example:

  ```python
- >>> from transformers import AutoTokenizer, MiniMaxText01ForCausalLM
+ >>> from transformers import AutoTokenizer, MiniMaxM1ForCausalLM

- >>> model = MiniMaxText01ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
+ >>> model = MiniMaxM1ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)

  >>> prompt = "Hey, are you conscious? Can you talk to me?"
@@ -1579,9 +1579,9 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):

  @add_start_docstrings(
  """
- The MiniMaxText01 Model transformer with a sequence classification head on top (linear layer).
+ The MiniMaxM1 Model transformer with a sequence classification head on top (linear layer).

- [`MiniMaxText01ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ [`MiniMaxM1ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  (e.g. GPT-2) do.

  Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -1592,12 +1592,12 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
  """,
  MIXTRAL_START_DOCSTRING,
  )
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxText01, LLAMA->MIXTRAL
- class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel):
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxM1, LLAMA->MIXTRAL
+ class MiniMaxM1ForSequenceClassification(MiniMaxM1PreTrainedModel):
  def __init__(self, config):
  super().__init__(config)
  self.num_labels = config.num_labels
- self.model = MiniMaxText01Model(config)
+ self.model = MiniMaxM1Model(config)
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

  # Initialize weights and apply final processing
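The renames above leave the layer-construction logic untouched: `build_attn` still selects `MiniMaxM1LightningAttention` when `config.attention_type == 0` and `MiniMaxM1FlashAttention2` otherwise, and `MiniMaxM1Model.__init__` overrides `attention_type` to 1 for selected layers, producing a hybrid stack. Below is a minimal, self-contained sketch of that dispatch pattern; the toy attention modules and the every-eighth-layer predicate are placeholders for illustration, not the repository's implementation:

```python
import torch
import torch.nn as nn


class ToyLightningAttention(nn.Module):
    """Stand-in for MiniMaxM1LightningAttention (linear attention)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyFlashAttention(nn.Module):
    """Stand-in for MiniMaxM1FlashAttention2 (softmax attention)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def build_attn(attention_type: int, hidden_size: int) -> nn.Module:
    # Mirrors the branch in MiniMaxM1DecoderLayer.build_attn:
    # attention_type == 0 -> lightning (linear) attention, otherwise flash attention.
    attention_module = ToyLightningAttention if attention_type == 0 else ToyFlashAttention
    return attention_module(hidden_size)


def uses_softmax_attention(layer_idx: int) -> bool:
    # Hypothetical per-layer rule for illustration only; the actual selection
    # happens in MiniMaxM1Model.__init__ and is not shown in these hunks.
    return (layer_idx + 1) % 8 == 0


layers = nn.ModuleList(
    [build_attn(1 if uses_softmax_attention(i) else 0, hidden_size=64) for i in range(16)]
)
x = torch.randn(2, 8, 64)
print([type(layer).__name__ for layer in layers])  # flash-style module every 8th layer
print(layers[0](x).shape)  # torch.Size([2, 8, 64])
```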