File size: 1,312 Bytes
9fad7fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import math
from typing import Optional , Union
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
    """Configuration object holding the hyper-parameters of a Mamba model.

    Most constructor arguments are stored on the instance unchanged. Three
    attributes are derived or normalised at construction time:

    * ``d_inner`` is computed as ``int(expand * d_model)``.
    * ``dt_rank`` is replaced by ``math.ceil(d_model / 16)`` when the
      caller passes the string ``"auto"``; any other value is kept as is.
    * ``vocab_size`` is rounded up to the next multiple of
      ``pad_vocab_size_multiple`` when it is not already a multiple.

    Args:
        vocab_size: Vocabulary size before padding (see above).
        d_state: Stored unchanged as ``self.d_state``.
        d_model: Stored unchanged; also used to derive ``d_inner`` and the
            automatic ``dt_rank``.
        d_conv: Stored unchanged as ``self.d_conv``.
        expand: Stored unchanged; multiplier used to derive ``d_inner``.
        conv_bias: Stored unchanged as ``self.conv_bias``.
        bias: Stored unchanged as ``self.bias``.
        n_layer: Stored unchanged as ``self.n_layer``.
        dt_rank: Either an explicit value or ``"auto"`` (see above).
        pad_vocab_size_multiple: Multiple that ``vocab_size`` is padded up to.
        initializer_range: Stored unchanged as ``self.initializer_range``.
        **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier string for this configuration class.
    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        **kwargs,
    ):
        # Plain pass-through hyper-parameters.
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.n_layer = n_layer
        self.bias = bias
        self.conv_bias = conv_bias
        self.initializer_range = initializer_range
        self.pad_vocab_size_multiple = pad_vocab_size_multiple

        # Derived width: expand * d_model, truncated to an integer.
        self.d_inner = int(expand * d_model)

        # "auto" resolves to ceil(d_model / 16); anything else is kept as given.
        self.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

        # Round the vocabulary size up to the requested multiple.
        remainder = vocab_size % pad_vocab_size_multiple
        if remainder != 0:
            vocab_size += pad_vocab_size_multiple - remainder
        self.vocab_size = vocab_size

        # Let the base class consume all remaining keyword arguments.
        super().__init__(**kwargs)