ModernBERT Decoder
ModernBERT Decoder has the same architecture as ModernBERT but is trained from scratch with a causal language modeling (CLM) objective. Sharing one architecture makes it possible to compare encoders and decoders directly. This is the decoder implementation of ModernBERT, designed for autoregressive text generation tasks.
Like the encoder version, ModernBERT Decoder incorporates modern architectural improvements such as rotary positional embeddings to support sequences of up to 8192 tokens, unpadding to avoid wasting compute on padding tokens, GeGLU layers, and alternating attention patterns. However, it uses causal (unidirectional) attention to enable autoregressive generation.
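The difference between causal (global) attention and causal sliding-window (local) attention can be sketched with plain boolean masks. This is a minimal illustration, not the library's implementation: the helper names are hypothetical, and the exact window boundary (here, the `window` most recent positions including the current one) is an assumption made for readability.

```python
def causal_mask(seq_len):
    """Row i marks which positions j token i may attend to (True = allowed)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def sliding_causal_mask(seq_len, window):
    """Causal mask restricted to the `window` most recent positions (self included)."""
    return [[(j <= i) and (i - j < window) for j in range(seq_len)] for i in range(seq_len)]


def show(mask):
    """Render a mask as text: 'x' for allowed positions, '.' for masked ones."""
    return "\n".join("".join("x" if allowed else "." for allowed in row) for row in mask)


full = causal_mask(5)
local = sliding_causal_mask(5, window=2)
print(show(full))
print()
print(show(local))
```

In both masks no token can see a later position, which is what makes autoregressive generation possible; the local variant additionally drops older positions to reduce compute.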
Click on the ModernBERT Decoder models in the right sidebar for more examples of how to apply ModernBERT Decoder to different text generation tasks.
The example below demonstrates how to use ModernBERT Decoder for text generation and sequence classification with Pipeline.
import torch
from transformers import pipeline

generator = pipeline(
    task="text-generation",
    model="blab-jhu/test-32m-dec",
    torch_dtype=torch.float16,
    device=0,
)
generator("The future of artificial intelligence is", max_length=50, num_return_sequences=1)

# For sequence classification
classifier = pipeline(
    task="text-classification",
    model="blab-jhu/test-32m-dec",
    torch_dtype=torch.float16,
    device=0,
)
classifier("This movie is really great!")
ModernBertDecoderConfig
class transformers.ModernBertDecoderConfig
< source >( vocab_size = 50368 hidden_size = 768 intermediate_size = 1152 num_hidden_layers = 22 num_attention_heads = 12 hidden_activation = 'gelu' max_position_embeddings = 8192 initializer_range = 0.02 initializer_cutoff_factor = 2.0 norm_eps = 1e-05 norm_bias = False pad_token_id = 50283 eos_token_id = 50282 bos_token_id = 50281 cls_token_id = 50281 sep_token_id = 50282 global_rope_theta = 160000.0 attention_bias = False attention_dropout = 0.0 embedding_dropout = 0.0 mlp_bias = False mlp_dropout = 0.0 decoder_bias = True classifier_dropout = 0.0 classifier_bias = False classifier_activation = 'gelu' use_cache = True local_attention = 128 global_attn_every_n_layers = 3 local_rope_theta = 160000.0 layer_types = None **kwargs )
Parameters
- vocab_size (int, optional, defaults to 50368) — Vocabulary size of the ModernBert decoder model. Defines the number of different tokens that can be represented by the input_ids passed when calling ModernBertDecoderModel.
- hidden_size (int, optional, defaults to 768) — Dimension of the hidden representations.
- intermediate_size (int, optional, defaults to 1152) — Dimension of the MLP representations.
- num_hidden_layers (int, optional, defaults to 22) — Number of hidden layers in the Transformer decoder.
- num_attention_heads (int, optional, defaults to 12) — Number of attention heads for each attention layer in the Transformer decoder.
- hidden_activation (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the decoder. Will default to "gelu" if not specified.
- max_position_embeddings (int, optional, defaults to 8192) — The maximum sequence length that this model might ever be used with.
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- initializer_cutoff_factor (float, optional, defaults to 2.0) — The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
- norm_eps (float, optional, defaults to 1e-05) — The epsilon used by the rms normalization layers.
- norm_bias (bool, optional, defaults to False) — Whether to use bias in the normalization layers.
- pad_token_id (int, optional, defaults to 50283) — Padding token id.
- eos_token_id (int, optional, defaults to 50282) — End of stream token id.
- bos_token_id (int, optional, defaults to 50281) — Beginning of stream token id.
- cls_token_id (int, optional, defaults to 50281) — Classification token id.
- sep_token_id (int, optional, defaults to 50282) — Separation token id.
- global_rope_theta (float, optional, defaults to 160000.0) — The base period of the global RoPE embeddings.
- attention_bias (bool, optional, defaults to False) — Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
- embedding_dropout (float, optional, defaults to 0.0) — The dropout ratio for the embeddings.
- mlp_bias (bool, optional, defaults to False) — Whether to use bias in the MLP layers.
- mlp_dropout (float, optional, defaults to 0.0) — The dropout ratio for the MLP layers.
- decoder_bias (bool, optional, defaults to True) — Whether to use bias in the decoder layers.
- classifier_dropout (float, optional, defaults to 0.0) — The dropout ratio for the classifier.
- classifier_bias (bool, optional, defaults to False) — Whether to use bias in the classifier.
- classifier_activation (str, optional, defaults to "gelu") — The activation function for the classifier.
- use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
- local_attention (int, optional, defaults to 128) — The sliding window size for local attention. Only used for layers that use local attention. Note that for the decoder to match ModernBERT this is actually half of the sliding window size, so 128 => 64.
- global_attn_every_n_layers (int, optional, defaults to 3) — Every global_attn_every_n_layers layers will use global attention instead of local attention.
- local_rope_theta (float, optional, defaults to 160000.0) — The base period of the local RoPE embeddings. If not specified, defaults to 160000.0.
- layer_types (list, optional) — List of layer types, one for each layer. If not specified, will be automatically generated based on global_attn_every_n_layers. Should contain "full_attention" or "sliding_attention".
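As a rough illustration of how layer_types could be generated from global_attn_every_n_layers, the sketch below builds the list for the default 22 layers. The function name is hypothetical, and which layer index starts the pattern (here, layer 0 is global) is an assumption; check the resulting config.layer_types on a real configuration object if the exact placement matters.

```python
def build_layer_types(num_hidden_layers=22, global_attn_every_n_layers=3):
    """Return one attention type per layer, alternating local and global layers.

    Assumption: layers whose index is a multiple of `global_attn_every_n_layers`
    use global ("full_attention") attention; all others use sliding-window attention.
    """
    layer_types = []
    for layer_id in range(num_hidden_layers):
        if layer_id % global_attn_every_n_layers == 0:
            layer_types.append("full_attention")
        else:
            layer_types.append("sliding_attention")
    return layer_types


layer_types = build_layer_types()
print(layer_types[:4])
print(layer_types.count("full_attention"), layer_types.count("sliding_attention"))
```

Passing an explicit layer_types list to the configuration would bypass this kind of automatic generation.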
This is the configuration class to store the configuration of a ModernBertDecoderModel. It is used to instantiate a ModernBert decoder model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ModernBERT-base decoder, e.g. blab-jhu/test-32m-dec.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Examples:
>>> from transformers import ModernBertDecoderModel, ModernBertDecoderConfig
>>> # Initializing a ModernBert decoder style configuration
>>> configuration = ModernBertDecoderConfig()
>>> # Initializing a model from the modernbert-base decoder style configuration
>>> model = ModernBertDecoderModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
ModernBertDecoderModel
class transformers.ModernBertDecoderModel
< source >( config: ModernBertDecoderConfig )
Parameters
- config (ModernBertDecoderConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare ModernBERT Decoder Model outputting raw hidden-states without any specific head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.Tensor] = None use_cache: typing.Optional[bool] = None cache_position: typing.Optional[torch.LongTensor] = None **kwargs ) → transformers.modeling_outputs.BaseModelOutputWithPast or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
  Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
- position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- past_key_values (~cache_utils.Cache, optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.
  Only a Cache instance is allowed as input, see our kv cache guide. If no past_key_values are passed, a DynamicCache will be initialized by default.
  The model will output the same cache format that is fed as input.
  If past_key_values are used, the user is expected to input only unprocessed input_ids (those that don't have their past key value states given to this model) of shape (batch_size, unprocessed_length) instead of all input_ids of shape (batch_size, sequence_length).
- inputs_embeds (torch.Tensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
- cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Contrary to position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
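The relationship between past_key_values, the unprocessed input_ids, and cache_position described above can be sketched with plain lists. This is only an illustration of the shapes involved: the helper name and the token ids are hypothetical, and a real decoding loop would pass tensors and a Cache instance instead.

```python
def split_for_cache(input_ids, cached_length):
    """Return (unprocessed_ids, cache_position) for one decoding step.

    input_ids: list of rows of token ids, shape (batch_size, sequence_length)
    cached_length: number of leading positions whose key/value states are already cached
    """
    # Only the tokens without cached key/value states are fed to the model.
    unprocessed = [row[cached_length:] for row in input_ids]
    sequence_length = len(input_ids[0])
    # Absolute positions of the unprocessed tokens within the full sequence.
    cache_position = list(range(cached_length, sequence_length))
    return unprocessed, cache_position


full_ids = [[101, 7, 8, 9], [101, 4, 5, 6]]  # hypothetical token ids
new_ids, positions = split_for_cache(full_ids, cached_length=3)
print(new_ids)    # shape (batch_size, unprocessed_length)
print(positions)  # shape (unprocessed_length,)
```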
Returns
transformers.modeling_outputs.BaseModelOutputWithPast or tuple(torch.FloatTensor)
A transformers.modeling_outputs.BaseModelOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ModernBertDecoderConfig) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
  If past_key_values is used only the last hidden-state of the sequences of shape (batch_size, 1, hidden_size) is output.
- past_key_values (Cache, optional, returned when use_cache=True is passed or when config.use_cache=True) — It is a Cache instance. For more details, see our kv cache guide.
  Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if config.is_encoder_decoder=True in the cross-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
  Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The ModernBertDecoderModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
ModernBertDecoderForCausalLM
class transformers.ModernBertDecoderForCausalLM
< source >( config: ModernBertDecoderConfig )
Parameters
- config (ModernBertDecoderConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The ModernBert Decoder Model with a language modeling head on top for causal language modeling (CLM).
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None **kwargs )
Parameters
- input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
  Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
- position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- past_key_values (~cache_utils.Cache, optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.
  Only a Cache instance is allowed as input, see our kv cache guide. If no past_key_values are passed, a DynamicCache will be initialized by default.
  The model will output the same cache format that is fed as input.
  If past_key_values are used, the user is expected to input only unprocessed input_ids (those that don't have their past key value states given to this model) of shape (batch_size, unprocessed_length) instead of all input_ids of shape (batch_size, sequence_length).
- inputs_embeds (torch.Tensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the causal language modeling loss. Indices should either be in [0, ..., config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked), the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].
- use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
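To make the role of labels and the -100 ignore index concrete, here is a small pure-Python sketch of a next-token cross-entropy loss. It assumes, as is common for causal language modeling heads, that labels are shifted by one position inside the model so that position t is scored against the token at position t + 1; the function name is hypothetical and this is not the library's code.

```python
import math

IGNORE_INDEX = -100


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for a single sequence.

    logits: per-position score lists, shape (sequence_length, vocab_size)
    labels: token ids of the same length; entries equal to -100 are skipped
    """
    total, count = 0.0, 0
    # Position t predicts the token at position t + 1 (assumed internal shift).
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        max_score = max(scores)
        # Numerically stable log-sum-exp over the vocabulary.
        log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        total += log_norm - scores[target]
        count += 1
    return total / count if count else 0.0


uniform_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
loss = causal_lm_loss(uniform_logits, labels=[0, 1, IGNORE_INDEX])
print(round(loss, 4))
```

Setting prompt or padding positions to -100 in labels is the usual way to exclude them from the loss.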
The ModernBertDecoderForCausalLM forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoTokenizer, ModernBertDecoderForCausalLM
>>> model = ModernBertDecoderForCausalLM.from_pretrained("blab-jhu/test-32m-dec")
>>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
>>> prompt = "The capital of France is"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_new_tokens=1)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The capital of France is Paris"
ModernBertDecoderForSequenceClassification
class transformers.ModernBertDecoderForSequenceClassification
< source >( config: ModernBertDecoderConfig )
Parameters
- config (ModernBertDecoderConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The ModernBert Decoder Model with a sequence classification head on top (linear layer).
ModernBertDecoderForSequenceClassification uses the last token in order to do the classification, as other causal models (e.g. GPT-1, GPT-2) do.
Since it does classification on the last token, it needs to know the position of the last token. If a pad_token_id is defined in the configuration, it finds the last token that is not a padding token in each row. If no pad_token_id is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when inputs_embeds are passed instead of input_ids, it does the same (takes the last value in each row of the batch).
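The last-token lookup described above can be sketched in plain Python. The helper name is hypothetical and the token ids are made up for illustration; only the pad token id (50283) comes from the default configuration.

```python
def last_non_pad_index(input_ids, pad_token_id=None):
    """Return, for each row, the index of the token used for classification.

    With a pad_token_id, this is the last position that is not padding;
    without one, it is simply the last position of the row.
    """
    indices = []
    for row in input_ids:
        last = len(row) - 1
        if pad_token_id is not None:
            non_pad = [pos for pos, token in enumerate(row) if token != pad_token_id]
            if non_pad:
                last = non_pad[-1]
        indices.append(last)
    return indices


PAD = 50283  # default pad_token_id from ModernBertDecoderConfig
batch = [[5, 6, 7, PAD, PAD], [8, 9, 10, 11, 12]]  # hypothetical token ids
print(last_non_pad_index(batch, pad_token_id=PAD))
print(last_non_pad_index(batch))
```

Without a pad token id, the padded first row would be classified from a padding position, which is why defining pad_token_id matters for batched inputs.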
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[transformers.cache_utils.Cache] = None inputs_embeds: typing.Optional[torch.Tensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None **kwargs ) → transformers.modeling_outputs.SequenceClassifierOutputWithPast or tuple(torch.FloatTensor)
Parameters
- input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
  Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.__call__() for details.
- attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
- position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].
- past_key_values (~cache_utils.Cache, optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True.
  Only a Cache instance is allowed as input, see our kv cache guide. If no past_key_values are passed, a DynamicCache will be initialized by default.
  The model will output the same cache format that is fed as input.
  If past_key_values are used, the user is expected to input only unprocessed input_ids (those that don't have their past key value states given to this model) of shape (batch_size, unprocessed_length) instead of all input_ids of shape (batch_size, sequence_length).
- inputs_embeds (torch.Tensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix.
- labels (torch.LongTensor of shape (batch_size,), optional) — Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss). If config.num_labels > 1 a classification loss is computed (Cross-Entropy).
- use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
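The rule for labels above (regression when config.num_labels == 1, cross-entropy otherwise) can be sketched in pure Python. The function name is hypothetical and the sketch deliberately leaves out the multi-label case; it only illustrates which loss applies, not how the library computes it.

```python
import math


def sequence_classification_loss(logits, labels, num_labels):
    """Mean loss over a batch: MSE if num_labels == 1, else cross-entropy.

    logits: per-example score lists, shape (batch_size, num_labels)
    labels: floats for regression, class indices for classification
    """
    if num_labels == 1:
        # Regression: mean squared error between the single score and the target.
        errors = [(row[0] - target) ** 2 for row, target in zip(logits, labels)]
        return sum(errors) / len(errors)
    losses = []
    for row, target in zip(logits, labels):
        max_score = max(row)
        # Numerically stable log-sum-exp over the classes.
        log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in row))
        losses.append(log_norm - row[target])
    return sum(losses) / len(losses)


regression_loss = sequence_classification_loss([[2.0], [0.0]], [1.0, 0.0], num_labels=1)
classification_loss = sequence_classification_loss([[0.0, 0.0]], [1], num_labels=2)
print(regression_loss, round(classification_loss, 4))
```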
Returns
transformers.modeling_outputs.SequenceClassifierOutputWithPast or tuple(torch.FloatTensor)
A transformers.modeling_outputs.SequenceClassifierOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ModernBertDecoderConfig) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.
- logits (torch.FloatTensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- past_key_values (Cache, optional, returned when use_cache=True is passed or when config.use_cache=True) — It is a Cache instance. For more details, see our kv cache guide.
  Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
  Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The ModernBertDecoderForSequenceClassification forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example of single-label classification:
>>> import torch
>>> from transformers import AutoTokenizer, ModernBertDecoderForSequenceClassification
>>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
>>> model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> predicted_class_id = logits.argmax().item()
>>> model.config.id2label[predicted_class_id]
...
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec", num_labels=num_labels)
>>> labels = torch.tensor([1])
>>> loss = model(**inputs, labels=labels).loss
>>> round(loss.item(), 2)
...
Example of multi-label classification:
>>> import torch
>>> from transformers import AutoTokenizer, ModernBertDecoderForSequenceClassification
>>> tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
>>> model = ModernBertDecoderForSequenceClassification.from_pretrained("blab-jhu/test-32m-dec", problem_type="multi_label_classification")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = ModernBertDecoderForSequenceClassification.from_pretrained(
...     "blab-jhu/test-32m-dec", num_labels=num_labels, problem_type="multi_label_classification"
... )
>>> labels = torch.sum(
...     torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
... ).to(torch.float)
>>> loss = model(**inputs, labels=labels).loss
Usage tips
The ModernBertDecoder model can be fine-tuned for various text generation tasks using the Hugging Face Transformers library. It supports efficient inference with features like:
- Causal attention: Ensures autoregressive generation by masking future tokens
- Sliding window attention: Alternates between local and global attention patterns for efficiency
- Rotary positional embeddings: Enables handling of longer sequences of up to 8192 tokens
- FlashAttention support: Optimized attention computation for faster training and inference
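To give a feel for what rotary positional embeddings do, the sketch below rotates pairs of vector components by position-dependent angles and checks that the query-key dot product depends only on the relative offset between positions. It is a simplified illustration under stated assumptions: the helper names are hypothetical, adjacent components are paired for readability (library implementations often pair component i with component i + head_dim/2 instead), and theta uses the 160000.0 default listed in the configuration.

```python
import math


def rope_inv_freq(head_dim, theta=160000.0):
    """Inverse frequencies theta^(-2i/d) for each of the head_dim/2 rotation planes."""
    return [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def apply_rope(vector, position, theta=160000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle position * inv_freq[i]."""
    rotated = []
    for i, freq in enumerate(rope_inv_freq(len(vector), theta)):
        angle = position * freq
        x, y = vector[2 * i], vector[2 * i + 1]
        rotated.append(x * math.cos(angle) - y * math.sin(angle))
        rotated.append(x * math.sin(angle) + y * math.cos(angle))
    return rotated


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


head_dim = 768 // 12  # hidden_size / num_attention_heads with the default config
query = [0.1 * (i % 7) for i in range(head_dim)]
key = [0.05 * (i % 5) for i in range(head_dim)]

# Same relative offset (2) at different absolute positions gives the same score.
near = dot(apply_rope(query, 5), apply_rope(key, 3))
far = dot(apply_rope(query, 102), apply_rope(key, 100))
print(abs(near - far) < 1e-9)
```

Because attention scores depend on relative offsets rather than absolute positions, the same rotation scheme applies uniformly across the whole context window.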