# Mamba-790M / configuration_mamba.py
# Q-bert's picture
# Upload 2 files
# 8af6a54
# raw
# history blame
# 1.31 kB
import math
from typing import Optional , Union
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
    """Hyper-parameter container registered under ``model_type = "mamba"``.

    Every constructor argument is stored on the instance under its own name.
    Three values are derived rather than stored verbatim:

    * ``d_inner`` is ``int(expand * d_model)``.
    * ``dt_rank`` given as the string ``"auto"`` is replaced by
      ``ceil(d_model / 16)``; any other value is kept unchanged.
    * ``vocab_size`` is rounded up to the next multiple of
      ``pad_vocab_size_multiple`` when it is not already one.

    Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): what the individual sizes (``d_state``, ``d_conv``,
    ``n_layer``, ...) mean is decided by the modeling code, which is not
    visible from this file -- confirm there before changing defaults.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        **kwargs,
    ):
        # Work out the derived quantities first, in the same order the
        # values depend on each other.
        inner_width = int(expand * d_model)
        resolved_dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

        # How far the vocabulary overshoots a whole multiple of the padding
        # unit; zero means it is already aligned.
        overshoot = vocab_size % pad_vocab_size_multiple
        padded_vocab_size = (
            vocab_size
            if overshoot == 0
            else vocab_size + (pad_vocab_size_multiple - overshoot)
        )

        # Publish everything on the instance.
        self.vocab_size = padded_vocab_size
        self.n_layer = n_layer
        self.conv_bias = conv_bias
        self.expand = expand
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = inner_width
        self.dt_rank = resolved_dt_rank
        self.initializer_range = initializer_range
        self.bias = bias

        # Let the base class consume whatever options were not named above.
        super().__init__(**kwargs)