HaileyStorm
committed on
Upload chess-mamba-vs-xformer/mamba_lm.py with huggingface_hub
chess-mamba-vs-xformer/mamba_lm.py
CHANGED
@@ -5,7 +5,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from mamba import Mamba, MambaConfig, RMSNorm
+#from mamba import Mamba, MambaConfig, RMSNorm
+from mamba_ssm import MambaLMHeadModel
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.ops.triton.layernorm import RMSNorm
 
 """
 
@@ -22,15 +25,18 @@ class MambaLMConfig(MambaConfig):
     pad_vocab_size_multiple: int = 8
 
     def __post_init__(self):
-        super().__post_init__()
+        pass
+        #super().__post_init__()
 
         #if self.vocab_size % self.pad_vocab_size_multiple != 0:
        #    self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
 
     def to_mamba_config(self) -> MambaConfig:
-        mamba_config_fields = {field.name for field in fields(MambaConfig)}
-        filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
-        return MambaConfig(**filtered_dict)
+        #mamba_config_fields = {field.name for field in fields(MambaConfig)}
+        #print(mamba_config_fields)
+        #filtered_dict = {k: v for k, v in asdict(self).items() if k in mamba_config_fields}
+        #return MambaConfig(**filtered_dict)
+        return MambaConfig(d_model=self.d_model, n_layer=self.n_layer, vocab_size=self.vocab_size, ssm_cfg=self.ssm_cfg)
 
 # adapted from https://github.com/johnma2006/mamba-minimal
 def from_pretrained(name: str):
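The lines commented out above implement a generic dataclass-filtering pattern; the hunk replaces it with an explicit constructor call. A standalone sketch of that retired pattern, assuming MambaConfig is the plain dataclass from mamba_ssm (the helper name here is illustrative, not from this repo):

from dataclasses import asdict, fields
from mamba_ssm.models.config_mamba import MambaConfig

def filter_to_mamba_config(lm_config) -> MambaConfig:
    # Keep only the keys that MambaConfig itself declares, then rebuild it.
    mamba_config_fields = {f.name for f in fields(MambaConfig)}
    filtered = {k: v for k, v in asdict(lm_config).items() if k in mamba_config_fields}
    return MambaConfig(**filtered)

The commit instead names the four fields it forwards (d_model, n_layer, vocab_size, ssm_cfg), which keeps the mapping visible at the call site.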
@@ -65,7 +71,8 @@ def from_pretrained(name: str):
     config_data = load_config_hf(name)
     config = MambaLMConfig(d_model=config_data['d_model'], n_layers=config_data['n_layer'], vocab_size=config_data['vocab_size'])
 
-    model = MambaLM(config)
+    #model = MambaLM(config)
+    model = MambaLMHeadModel(config)
 
     # copy weights
     state_dict = load_state_dict_hf(name)
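A rough usage sketch tying this hunk to the mamba_ssm imports introduced earlier; the hyperparameters and the repo name below are placeholders, not values from this project:

from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

# Direct construction, mirroring what from_pretrained now does after reading the checkpoint's config.
config = MambaConfig(d_model=512, n_layer=8, vocab_size=32, ssm_cfg={})
model = MambaLMHeadModel(config)

# Or via the loader defined in this file, with any Hugging Face repo whose config
# exposes d_model, n_layer and vocab_size (example name only).
model = from_pretrained('state-spaces/mamba-130m')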
@@ -90,7 +97,7 @@ class MambaLM(nn.Module):
         self.config = lm_config.to_mamba_config()
 
         self.embedding = nn.Embedding(self.lm_config.vocab_size, self.config.d_model)
-        self.mamba = Mamba(self.config)
+        self.mamba = Mamba(**self.config.__dict__)
         self.norm_f = RMSNorm(self.config.d_model)
 
         self.lm_head = nn.Linear(self.config.d_model, self.lm_config.vocab_size, bias=False)
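For context on how the four submodules in the last hunk compose, a minimal sketch of a language-model forward pass; the shape comments are assumptions inferred from the layer definitions, not this file's actual forward():

import torch
import torch.nn as nn

def lm_forward_sketch(embedding: nn.Embedding, mamba: nn.Module,
                      norm_f: nn.Module, lm_head: nn.Linear,
                      tokens: torch.Tensor) -> torch.Tensor:
    x = embedding(tokens)   # (B, L) token ids -> (B, L, d_model)
    x = mamba(x)            # sequence mixing; assumed (B, L, d_model) in and out
    x = norm_f(x)           # final RMSNorm
    return lm_head(x)       # (B, L, vocab_size) logits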