fix code
Browse files- configuration_minimax_m1.py +14 -14
- modeling_minimax_m1.py +57 -57
configuration_minimax_m1.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
"""
|
2 |
|
3 |
from transformers.configuration_utils import PretrainedConfig
|
4 |
from transformers.utils import logging
|
@@ -7,11 +7,11 @@ from transformers.utils import logging
|
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
9 |
|
10 |
-
class
|
11 |
r"""
|
12 |
-
This is the configuration class to store the configuration of a [`
|
13 |
-
|
14 |
-
with the defaults will yield a similar configuration to that of the
|
15 |
|
16 |
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
17 |
documentation from [`PretrainedConfig`] for more information.
|
@@ -19,8 +19,8 @@ class MiniMaxText01Config(PretrainedConfig):
|
|
19 |
|
20 |
Args:
|
21 |
vocab_size (`int`, *optional*, defaults to 32000):
|
22 |
-
Vocabulary size of the
|
23 |
-
`inputs_ids` passed when calling [`
|
24 |
hidden_size (`int`, *optional*, defaults to 4096):
|
25 |
Dimension of the hidden representations.
|
26 |
intermediate_size (`int`, *optional*, defaults to 14336):
|
@@ -39,7 +39,7 @@ class MiniMaxText01Config(PretrainedConfig):
|
|
39 |
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
40 |
The non-linear activation function (function or string) in the decoder.
|
41 |
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
42 |
-
The maximum sequence length that this model might ever be used with.
|
43 |
allows sequence of up to 4096*32 tokens.
|
44 |
initializer_range (`float`, *optional*, defaults to 0.02):
|
45 |
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
@@ -76,19 +76,19 @@ class MiniMaxText01Config(PretrainedConfig):
|
|
76 |
Amount of noise to add to the router.
|
77 |
|
78 |
```python
|
79 |
-
>>> from transformers import
|
80 |
|
81 |
-
>>> # Initializing a
|
82 |
-
>>> configuration =
|
83 |
|
84 |
-
>>> # Initializing a model from the
|
85 |
-
>>> model =
|
86 |
|
87 |
>>> # Accessing the model configuration
|
88 |
>>> configuration = model.config
|
89 |
```"""
|
90 |
|
91 |
-
model_type = "
|
92 |
keys_to_ignore_at_inference = ["past_key_values"]
|
93 |
|
94 |
def __init__(
|
|
|
1 |
+
""" MiniMaxM1 model configuration"""
|
2 |
|
3 |
from transformers.configuration_utils import PretrainedConfig
|
4 |
from transformers.utils import logging
|
|
|
7 |
logger = logging.get_logger(__name__)
|
8 |
|
9 |
|
10 |
+
class MiniMaxM1Config(PretrainedConfig):
|
11 |
r"""
|
12 |
+
This is the configuration class to store the configuration of a [`MiniMaxM1Model`]. It is used to instantiate an
|
13 |
+
MiniMaxM1 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
14 |
+
with the defaults will yield a similar configuration to that of the MiniMaxM1.
|
15 |
|
16 |
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
17 |
documentation from [`PretrainedConfig`] for more information.
|
|
|
19 |
|
20 |
Args:
|
21 |
vocab_size (`int`, *optional*, defaults to 32000):
|
22 |
+
Vocabulary size of the MiniMaxM1 model. Defines the number of different tokens that can be represented by the
|
23 |
+
`inputs_ids` passed when calling [`MiniMaxM1Model`]
|
24 |
hidden_size (`int`, *optional*, defaults to 4096):
|
25 |
Dimension of the hidden representations.
|
26 |
intermediate_size (`int`, *optional*, defaults to 14336):
|
|
|
39 |
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
40 |
The non-linear activation function (function or string) in the decoder.
|
41 |
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
42 |
+
The maximum sequence length that this model might ever be used with. MiniMaxM1's sliding window attention
|
43 |
allows sequence of up to 4096*32 tokens.
|
44 |
initializer_range (`float`, *optional*, defaults to 0.02):
|
45 |
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
|
76 |
Amount of noise to add to the router.
|
77 |
|
78 |
```python
|
79 |
+
>>> from transformers import MiniMaxM1Model, MiniMaxM1Config
|
80 |
|
81 |
+
>>> # Initializing a MiniMaxM1 style configuration
|
82 |
+
>>> configuration = MiniMaxM1Config()
|
83 |
|
84 |
+
>>> # Initializing a model from the MiniMaxM1 style configuration
|
85 |
+
>>> model = MiniMaxM1Model(configuration)
|
86 |
|
87 |
>>> # Accessing the model configuration
|
88 |
>>> configuration = model.config
|
89 |
```"""
|
90 |
|
91 |
+
model_type = "MiniMaxM1"
|
92 |
keys_to_ignore_at_inference = ["past_key_values"]
|
93 |
|
94 |
def __init__(
|
modeling_minimax_m1.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
""" PyTorch
|
2 |
import inspect
|
3 |
import math
|
4 |
import warnings
|
@@ -31,7 +31,7 @@ from transformers.utils import (
|
|
31 |
replace_return_docstrings,
|
32 |
)
|
33 |
from transformers.utils.import_utils import is_torch_fx_available
|
34 |
-
from .configuration_minimax_m1 import
|
35 |
|
36 |
if is_flash_attn_2_available():
|
37 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
@@ -52,7 +52,7 @@ BLOCK = 256
|
|
52 |
|
53 |
logger = logging.get_logger(__name__)
|
54 |
|
55 |
-
_CONFIG_FOR_DOC = "
|
56 |
|
57 |
|
58 |
def get_activation_fn(activation):
|
@@ -207,8 +207,8 @@ class GLU(nn.Module):
|
|
207 |
return output
|
208 |
|
209 |
|
210 |
-
class
|
211 |
-
def __init__(self, config:
|
212 |
super().__init__()
|
213 |
bias = False
|
214 |
self.hidden_size = config.hidden_size
|
@@ -217,7 +217,7 @@ class MiniMaxText01LightningAttention(nn.Module):
|
|
217 |
|
218 |
self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
|
219 |
self.act = get_activation_fn(config.hidden_act)
|
220 |
-
self.norm =
|
221 |
|
222 |
self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
|
223 |
self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
|
@@ -338,11 +338,11 @@ class MiniMaxText01LightningAttention(nn.Module):
|
|
338 |
return output, attn_weights, kv
|
339 |
|
340 |
|
341 |
-
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->
|
342 |
-
class
|
343 |
def __init__(self, hidden_size, eps=1e-6):
|
344 |
"""
|
345 |
-
|
346 |
"""
|
347 |
super().__init__()
|
348 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
@@ -356,8 +356,8 @@ class MiniMaxText01RMSNorm(nn.Module):
|
|
356 |
return self.weight * hidden_states.to(input_dtype)
|
357 |
|
358 |
|
359 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->
|
360 |
-
class
|
361 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
362 |
super().__init__()
|
363 |
|
@@ -447,14 +447,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
447 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
448 |
|
449 |
|
450 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->
|
451 |
-
class
|
452 |
"""
|
453 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
454 |
and "Generating Long Sequences with Sparse Transformers".
|
455 |
"""
|
456 |
|
457 |
-
def __init__(self, config:
|
458 |
super().__init__()
|
459 |
self.config = config
|
460 |
self.layer_idx = layer_idx
|
@@ -481,7 +481,7 @@ class MiniMaxText01Attention(nn.Module):
|
|
481 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
482 |
self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
|
483 |
|
484 |
-
self.rotary_emb =
|
485 |
self.rotary_dim,
|
486 |
max_position_embeddings=self.max_position_embeddings,
|
487 |
base=self.rope_theta,
|
@@ -572,10 +572,10 @@ class MiniMaxText01Attention(nn.Module):
|
|
572 |
return attn_output, attn_weights, past_key_value
|
573 |
|
574 |
|
575 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->
|
576 |
-
class
|
577 |
"""
|
578 |
-
|
579 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
580 |
flash attention and deal with padding tokens in case the input contains any of them.
|
581 |
"""
|
@@ -836,7 +836,7 @@ class MiniMaxText01FlashAttention2(MiniMaxText01Attention):
|
|
836 |
)
|
837 |
|
838 |
|
839 |
-
class
|
840 |
def __init__(self, config):
|
841 |
super().__init__()
|
842 |
self.config = config
|
@@ -852,8 +852,8 @@ class MiniMaxText01MLP(nn.Module):
|
|
852 |
return down_proj
|
853 |
|
854 |
|
855 |
-
class
|
856 |
-
def __init__(self, config:
|
857 |
super().__init__()
|
858 |
self.ffn_dim = config.intermediate_size
|
859 |
self.hidden_dim = config.hidden_size
|
@@ -870,15 +870,15 @@ class MiniMaxText01BlockSparseTop2MLP(nn.Module):
|
|
870 |
return current_hidden_states
|
871 |
|
872 |
|
873 |
-
class
|
874 |
def __init__(self, *args, **kwargs):
|
875 |
logger.warning_once(
|
876 |
-
"
|
877 |
)
|
878 |
super().__init__(*args, **kwargs)
|
879 |
|
880 |
|
881 |
-
class
|
882 |
"""
|
883 |
This implementation is
|
884 |
strictly equivalent to standard MoE with full capacity (no
|
@@ -900,7 +900,7 @@ class MiniMaxText01SparseMoeBlock(nn.Module):
|
|
900 |
# gating
|
901 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
902 |
|
903 |
-
self.experts = nn.ModuleList([
|
904 |
|
905 |
# Jitter parameters
|
906 |
self.jitter_noise = config.router_jitter_noise
|
@@ -946,8 +946,8 @@ class MiniMaxText01SparseMoeBlock(nn.Module):
|
|
946 |
return final_hidden_states, router_logits
|
947 |
|
948 |
|
949 |
-
class
|
950 |
-
def __init__(self, config:
|
951 |
super().__init__()
|
952 |
self.config = config
|
953 |
self.hidden_size = config.hidden_size
|
@@ -956,9 +956,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|
956 |
|
957 |
self.layer_idx = layer_idx
|
958 |
|
959 |
-
self.block_sparse_moe =
|
960 |
-
self.input_layernorm =
|
961 |
-
self.post_attention_layernorm =
|
962 |
|
963 |
self.postnorm = getattr(config, 'postnorm', False)
|
964 |
self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
|
@@ -972,14 +972,14 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|
972 |
self.shared_moe = False
|
973 |
if shared_intermediate > 0:
|
974 |
self.shared_moe = True
|
975 |
-
self.shared_mlp =
|
976 |
self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)
|
977 |
|
978 |
def build_attn(self, config, layer_idx):
|
979 |
if config.attention_type == 0:
|
980 |
-
Attention_module =
|
981 |
else:
|
982 |
-
Attention_module =
|
983 |
|
984 |
return Attention_module(
|
985 |
config,
|
@@ -1081,7 +1081,7 @@ MIXTRAL_START_DOCSTRING = r"""
|
|
1081 |
and behavior.
|
1082 |
|
1083 |
Parameters:
|
1084 |
-
config ([`
|
1085 |
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
1086 |
load the weights associated with the model, only the configuration. Check out the
|
1087 |
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
@@ -1089,15 +1089,15 @@ MIXTRAL_START_DOCSTRING = r"""
|
|
1089 |
|
1090 |
|
1091 |
@add_start_docstrings(
|
1092 |
-
"The bare
|
1093 |
MIXTRAL_START_DOCSTRING,
|
1094 |
)
|
1095 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->
|
1096 |
-
class
|
1097 |
-
config_class =
|
1098 |
base_model_prefix = "model"
|
1099 |
supports_gradient_checkpointing = True
|
1100 |
-
_no_split_modules = ["
|
1101 |
_skip_keys_device_placement = "past_key_values"
|
1102 |
_supports_flash_attn_2 = True
|
1103 |
_supports_sdpa = True
|
@@ -1182,19 +1182,19 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
|
|
1182 |
|
1183 |
|
1184 |
@add_start_docstrings(
|
1185 |
-
"The bare
|
1186 |
MIXTRAL_START_DOCSTRING,
|
1187 |
)
|
1188 |
-
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->
|
1189 |
-
class
|
1190 |
"""
|
1191 |
-
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`
|
1192 |
|
1193 |
Args:
|
1194 |
-
config:
|
1195 |
"""
|
1196 |
|
1197 |
-
def __init__(self, config:
|
1198 |
super().__init__(config)
|
1199 |
self.padding_idx = config.pad_token_id
|
1200 |
self.vocab_size = config.vocab_size
|
@@ -1212,10 +1212,10 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
|
|
1212 |
else:
|
1213 |
_config._attn_implementation = config_copy._attn_implementation
|
1214 |
_config.attention_type = 1
|
1215 |
-
self.layers.append(
|
1216 |
|
1217 |
self._attn_implementation = config_copy._attn_implementation
|
1218 |
-
self.norm =
|
1219 |
|
1220 |
self.gradient_checkpointing = False
|
1221 |
self.slopes = self._build_slope_tensor(config.num_attention_heads)
|
@@ -1327,7 +1327,7 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
|
|
1327 |
if is_padding_right:
|
1328 |
raise ValueError(
|
1329 |
"You are attempting to perform batched generation with padding_side='right'"
|
1330 |
-
" this may lead to unexpected behaviour for Flash Attention version of
|
1331 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
1332 |
)
|
1333 |
slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
|
@@ -1401,12 +1401,12 @@ class MiniMaxText01Model(MiniMaxText01PreTrainedModel):
|
|
1401 |
)
|
1402 |
|
1403 |
|
1404 |
-
class
|
1405 |
_tied_weights_keys = ["lm_head.weight"]
|
1406 |
|
1407 |
def __init__(self, config):
|
1408 |
super().__init__(config)
|
1409 |
-
self.model =
|
1410 |
self.vocab_size = config.vocab_size
|
1411 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1412 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
@@ -1462,9 +1462,9 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
|
|
1462 |
Example:
|
1463 |
|
1464 |
```python
|
1465 |
-
>>> from transformers import AutoTokenizer,
|
1466 |
|
1467 |
-
>>> model =
|
1468 |
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)
|
1469 |
|
1470 |
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
@@ -1579,9 +1579,9 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
|
|
1579 |
|
1580 |
@add_start_docstrings(
|
1581 |
"""
|
1582 |
-
The
|
1583 |
|
1584 |
-
[`
|
1585 |
(e.g. GPT-2) do.
|
1586 |
|
1587 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
@@ -1592,12 +1592,12 @@ class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel):
|
|
1592 |
""",
|
1593 |
MIXTRAL_START_DOCSTRING,
|
1594 |
)
|
1595 |
-
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->
|
1596 |
-
class
|
1597 |
def __init__(self, config):
|
1598 |
super().__init__(config)
|
1599 |
self.num_labels = config.num_labels
|
1600 |
-
self.model =
|
1601 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1602 |
|
1603 |
# Initialize weights and apply final processing
|
|
|
1 |
+
""" PyTorch MiniMaxM1 model."""
|
2 |
import inspect
|
3 |
import math
|
4 |
import warnings
|
|
|
31 |
replace_return_docstrings,
|
32 |
)
|
33 |
from transformers.utils.import_utils import is_torch_fx_available
|
34 |
+
from .configuration_minimax_m1 import MiniMaxM1Config
|
35 |
|
36 |
if is_flash_attn_2_available():
|
37 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
|
52 |
|
53 |
logger = logging.get_logger(__name__)
|
54 |
|
55 |
+
_CONFIG_FOR_DOC = "MiniMaxM1Config"
|
56 |
|
57 |
|
58 |
def get_activation_fn(activation):
|
|
|
207 |
return output
|
208 |
|
209 |
|
210 |
+
class MiniMaxM1LightningAttention(nn.Module):
|
211 |
+
def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
|
212 |
super().__init__()
|
213 |
bias = False
|
214 |
self.hidden_size = config.hidden_size
|
|
|
217 |
|
218 |
self.out_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=bias)
|
219 |
self.act = get_activation_fn(config.hidden_act)
|
220 |
+
self.norm = MiniMaxM1RMSNorm(self.head_dim * self.num_heads)
|
221 |
|
222 |
self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias)
|
223 |
self.output_gate = nn.Linear(self.hidden_size, self.head_dim * self.num_heads, bias=bias)
|
|
|
338 |
return output, attn_weights, kv
|
339 |
|
340 |
|
341 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxM1
|
342 |
+
class MiniMaxM1RMSNorm(nn.Module):
|
343 |
def __init__(self, hidden_size, eps=1e-6):
|
344 |
"""
|
345 |
+
MiniMaxM1RMSNorm is equivalent to T5LayerNorm
|
346 |
"""
|
347 |
super().__init__()
|
348 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
|
356 |
return self.weight * hidden_states.to(input_dtype)
|
357 |
|
358 |
|
359 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->MiniMaxM1
|
360 |
+
class MiniMaxM1RotaryEmbedding(nn.Module):
|
361 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
362 |
super().__init__()
|
363 |
|
|
|
447 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
448 |
|
449 |
|
450 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->MiniMaxM1
|
451 |
+
class MiniMaxM1Attention(nn.Module):
|
452 |
"""
|
453 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
454 |
and "Generating Long Sequences with Sparse Transformers".
|
455 |
"""
|
456 |
|
457 |
+
def __init__(self, config: MiniMaxM1Config, layer_idx: Optional[int] = None):
|
458 |
super().__init__()
|
459 |
self.config = config
|
460 |
self.layer_idx = layer_idx
|
|
|
481 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
482 |
self.rotary_dim = getattr(config, 'rotary_dim', self.head_dim)
|
483 |
|
484 |
+
self.rotary_emb = MiniMaxM1RotaryEmbedding(
|
485 |
self.rotary_dim,
|
486 |
max_position_embeddings=self.max_position_embeddings,
|
487 |
base=self.rope_theta,
|
|
|
572 |
return attn_output, attn_weights, past_key_value
|
573 |
|
574 |
|
575 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->MiniMaxM1
|
576 |
+
class MiniMaxM1FlashAttention2(MiniMaxM1Attention):
|
577 |
"""
|
578 |
+
MiniMaxM1 flash attention module. This module inherits from `MiniMaxM1Attention` as the weights of the module stays
|
579 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
580 |
flash attention and deal with padding tokens in case the input contains any of them.
|
581 |
"""
|
|
|
836 |
)
|
837 |
|
838 |
|
839 |
+
class MiniMaxM1MLP(nn.Module):
|
840 |
def __init__(self, config):
|
841 |
super().__init__()
|
842 |
self.config = config
|
|
|
852 |
return down_proj
|
853 |
|
854 |
|
855 |
+
class MiniMaxM1BlockSparseTop2MLP(nn.Module):
|
856 |
+
def __init__(self, config: MiniMaxM1Config):
|
857 |
super().__init__()
|
858 |
self.ffn_dim = config.intermediate_size
|
859 |
self.hidden_dim = config.hidden_size
|
|
|
870 |
return current_hidden_states
|
871 |
|
872 |
|
873 |
+
class MiniMaxM1BLockSparseTop2MLP(MiniMaxM1BlockSparseTop2MLP):
|
874 |
def __init__(self, *args, **kwargs):
|
875 |
logger.warning_once(
|
876 |
+
"MiniMaxM1BLockSparseTop2MLP is deprecated by MiniMaxM1BlockSparseTop2MLP and will be removed in v4.40."
|
877 |
)
|
878 |
super().__init__(*args, **kwargs)
|
879 |
|
880 |
|
881 |
+
class MiniMaxM1SparseMoeBlock(nn.Module):
|
882 |
"""
|
883 |
This implementation is
|
884 |
strictly equivalent to standard MoE with full capacity (no
|
|
|
900 |
# gating
|
901 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
902 |
|
903 |
+
self.experts = nn.ModuleList([MiniMaxM1BlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
904 |
|
905 |
# Jitter parameters
|
906 |
self.jitter_noise = config.router_jitter_noise
|
|
|
946 |
return final_hidden_states, router_logits
|
947 |
|
948 |
|
949 |
+
class MiniMaxM1DecoderLayer(nn.Module):
|
950 |
+
def __init__(self, config: MiniMaxM1Config, layer_idx: int):
|
951 |
super().__init__()
|
952 |
self.config = config
|
953 |
self.hidden_size = config.hidden_size
|
|
|
956 |
|
957 |
self.layer_idx = layer_idx
|
958 |
|
959 |
+
self.block_sparse_moe = MiniMaxM1SparseMoeBlock(config)
|
960 |
+
self.input_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
961 |
+
self.post_attention_layernorm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
962 |
|
963 |
self.postnorm = getattr(config, 'postnorm', False)
|
964 |
self.layernorm_attention_alpha = getattr(config, 'layernorm_linear_attention_alpha', 1) \
|
|
|
972 |
self.shared_moe = False
|
973 |
if shared_intermediate > 0:
|
974 |
self.shared_moe = True
|
975 |
+
self.shared_mlp = MiniMaxM1MLP(config)
|
976 |
self.coefficient = torch.nn.Linear(self.hidden_size, 1, bias=False)
|
977 |
|
978 |
def build_attn(self, config, layer_idx):
|
979 |
if config.attention_type == 0:
|
980 |
+
Attention_module = MiniMaxM1LightningAttention
|
981 |
else:
|
982 |
+
Attention_module = MiniMaxM1FlashAttention2
|
983 |
|
984 |
return Attention_module(
|
985 |
config,
|
|
|
1081 |
and behavior.
|
1082 |
|
1083 |
Parameters:
|
1084 |
+
config ([`MiniMaxM1Config`]):
|
1085 |
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
1086 |
load the weights associated with the model, only the configuration. Check out the
|
1087 |
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
|
1089 |
|
1090 |
|
1091 |
@add_start_docstrings(
|
1092 |
+
"The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
|
1093 |
MIXTRAL_START_DOCSTRING,
|
1094 |
)
|
1095 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->MiniMaxM1
|
1096 |
+
class MiniMaxM1PreTrainedModel(PreTrainedModel):
|
1097 |
+
config_class = MiniMaxM1Config
|
1098 |
base_model_prefix = "model"
|
1099 |
supports_gradient_checkpointing = True
|
1100 |
+
_no_split_modules = ["MiniMaxM1DecoderLayer"]
|
1101 |
_skip_keys_device_placement = "past_key_values"
|
1102 |
_supports_flash_attn_2 = True
|
1103 |
_supports_sdpa = True
|
|
|
1182 |
|
1183 |
|
1184 |
@add_start_docstrings(
|
1185 |
+
"The bare MiniMaxM1 Model outputting raw hidden-states without any specific head on top.",
|
1186 |
MIXTRAL_START_DOCSTRING,
|
1187 |
)
|
1188 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->MiniMaxM1
|
1189 |
+
class MiniMaxM1Model(MiniMaxM1PreTrainedModel):
|
1190 |
"""
|
1191 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxM1DecoderLayer`]
|
1192 |
|
1193 |
Args:
|
1194 |
+
config: MiniMaxM1Config
|
1195 |
"""
|
1196 |
|
1197 |
+
def __init__(self, config: MiniMaxM1Config):
|
1198 |
super().__init__(config)
|
1199 |
self.padding_idx = config.pad_token_id
|
1200 |
self.vocab_size = config.vocab_size
|
|
|
1212 |
else:
|
1213 |
_config._attn_implementation = config_copy._attn_implementation
|
1214 |
_config.attention_type = 1
|
1215 |
+
self.layers.append(MiniMaxM1DecoderLayer(_config, i))
|
1216 |
|
1217 |
self._attn_implementation = config_copy._attn_implementation
|
1218 |
+
self.norm = MiniMaxM1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1219 |
|
1220 |
self.gradient_checkpointing = False
|
1221 |
self.slopes = self._build_slope_tensor(config.num_attention_heads)
|
|
|
1327 |
if is_padding_right:
|
1328 |
raise ValueError(
|
1329 |
"You are attempting to perform batched generation with padding_side='right'"
|
1330 |
+
" this may lead to unexpected behaviour for Flash Attention version of MiniMaxM1. Make sure to "
|
1331 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
1332 |
)
|
1333 |
slope_rates = [self.slopes.to(default_device) for _ in range(len(self.layers))]
|
|
|
1401 |
)
|
1402 |
|
1403 |
|
1404 |
+
class MiniMaxM1ForCausalLM(MiniMaxM1PreTrainedModel):
|
1405 |
_tied_weights_keys = ["lm_head.weight"]
|
1406 |
|
1407 |
def __init__(self, config):
|
1408 |
super().__init__(config)
|
1409 |
+
self.model = MiniMaxM1Model(config)
|
1410 |
self.vocab_size = config.vocab_size
|
1411 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1412 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
|
|
1462 |
Example:
|
1463 |
|
1464 |
```python
|
1465 |
+
>>> from transformers import AutoTokenizer, MiniMaxM1ForCausalLM
|
1466 |
|
1467 |
+
>>> model = MiniMaxM1ForCausalLM.from_pretrained(PATH_TO_WEIGHTS)
|
1468 |
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_WEIGHTS)
|
1469 |
|
1470 |
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
|
1579 |
|
1580 |
@add_start_docstrings(
|
1581 |
"""
|
1582 |
+
The MiniMaxM1 Model transformer with a sequence classification head on top (linear layer).
|
1583 |
|
1584 |
+
[`MiniMaxM1ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1585 |
(e.g. GPT-2) do.
|
1586 |
|
1587 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
|
1592 |
""",
|
1593 |
MIXTRAL_START_DOCSTRING,
|
1594 |
)
|
1595 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->MiniMaxM1, LLAMA->MIXTRAL
|
1596 |
+
class MiniMaxM1ForSequenceClassification(MiniMaxM1PreTrainedModel):
|
1597 |
def __init__(self, config):
|
1598 |
super().__init__(config)
|
1599 |
self.num_labels = config.num_labels
|
1600 |
+
self.model = MiniMaxM1Model(config)
|
1601 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1602 |
|
1603 |
# Initialize weights and apply final processing
|