""" | |
LLaDA configuration | |
""" | |
from transformers import AutoConfig, PretrainedConfig | |
from enum import Enum | |
from os import PathLike | |
from typing import Union | |
from dataclasses import asdict, dataclass, field | |
from glob import glob | |
from pathlib import Path | |
from typing import ( | |
Any, | |
Dict, | |
Iterable, | |
List, | |
Optional, | |
Tuple, | |
Type, | |
TypeVar, | |
Union, | |
cast, | |
) | |
__all__ = [ | |
"ActivationType", | |
"ActivationCheckpointingStrategy", | |
"BlockType", | |
"LayerNormType", | |
"InitFnType", | |
"ModelConfig", | |
] | |
PathOrStr = Union[str, PathLike] | |
class StrEnum(str, Enum):
    """
    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
    We include this here for compatibility with older versions of Python.
    """

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f"'{str(self)}'"

class LayerNormType(StrEnum):
    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    gemma_rms = "gemma_rms"
    """
    A Gemma-style RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    amd_compatible = "amd_compatible"
    """
    LayerNorm implemented manually to work around an issue with ROCm.
    """

class ActivationType(StrEnum):
    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"


class BlockType(StrEnum):
    sequential = "sequential"
    parallel = "parallel"

    llama = "llama"
    """
    A block similar to the sequential block with slightly different
    implementations of operations like attention to imitate the behavior of Llama.
    """

class InitFnType(StrEnum):
    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW.
    This uses a truncated normal distribution with an adaptive standard deviation that depends
    on the size of the weights as well as the depth of the layer.
    """

    normal = "normal"
    """
    All weights are initialized from the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    All weights are initialized with the Kaiming method from a normal distribution.
    Note this currently won't work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """

@dataclass
class ModelConfig:
    """
    LLaDA (model) configuration.
    """

    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    If ``True``, apply a Gemma-style normalization to the input hidden states.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The RMS layernorm eps param.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The RoPE base param.
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in QKV linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the token to use as the mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution"
    ``init_fn``, such as "normal". Setting this to None means values are not cut off.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    def effective_n_kv_heads(self) -> int:
        """
        Resolve the effective number of key/value heads from ``n_kv_heads`` and
        ``multi_query_attention``, validating that the two settings don't conflict.
        """
        if self.n_kv_heads is None:
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise Exception(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )
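
# Example (illustrative, using the GPT2-style defaults above):
#   ModelConfig().effective_n_kv_heads()                            -> 12 (standard multi-head)
#   ModelConfig(multi_query_attention=True).effective_n_kv_heads()  -> 1  (multi-query)
#   ModelConfig(n_kv_heads=4).effective_n_kv_heads()                -> 4  (grouped-query, 3 query heads per KV head)
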

class ActivationCheckpointingStrategy(StrEnum):
    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one in two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one in three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one in four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Focus checkpointing on the operations that are cheap to recompute and save the most memory.
    """

class LLaDAConfig(PretrainedConfig):
    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        model_config = ModelConfig()
        all_kwargs = model_config.__dict__
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        all_kwargs.update(
            {"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])}
        )
        super().__init__(**all_kwargs)

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

    @property
    def hidden_size(self):
        return self.d_model


# Register the config class so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("llada", LLaDAConfig)
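

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; the hyper-parameter values here are made up).
    # Build a config, check the convenience properties, and print the JSON form that
    # ``save_pretrained`` would write.
    config = LLaDAConfig(d_model=1024, n_heads=16, n_layers=24, rope=True)
    print(config.model_type)   # "llada"
    print(config.hidden_size)  # 1024, resolved through the ``hidden_size`` property above
    print(config.to_json_string())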