HaileyStorm committed
Commit 07f1096
1 Parent(s): 5e634b7
Upload chess-gpt-eval/mamba_module.py with huggingface_hub
chess-gpt-eval/mamba_module.py
CHANGED
@@ -1,7 +1,8 @@
 import os
 import pickle
 import torch
-from mamba_lm import
+from mamba_lm import MambaLMConfig, from_pretrained
+from mamba_ssm import MambaLMHeadModel
 from contextlib import nullcontext

 BASE_DIR = "mamba/"
@@ -41,10 +42,10 @@ class MambaPlayer:
         # Model initialization
         if init_from == "resume":
             #ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
-            ckpt_path = os.path.normpath(f"
+            ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Mamba/{self.model_name}")
             checkpoint = torch.load(ckpt_path, map_location=device)
             model_config = checkpoint["model_args"]
-            model =
+            model = MambaLMHeadModel(model_config)
             model.load_state_dict(checkpoint['model'])
         elif init_from.startswith('state-spaces'):
             model = from_pretrained(init_from).to(device)
@@ -96,7 +97,7 @@ class MambaPlayer:
         with torch.no_grad():
             have_non_space = False
             for _ in range(max_new_tokens):
-                logits = self.model(input_ids)[0, -1, :] # Get logits for the last token
+                logits = self.model(input_ids).logits[0, -1, :] # Get logits for the last token

                 # Apply temperature scaling and optionally sample from top k tokens
                 logits = logits / temperature
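For context, the substantive change in the resume branch is that the model is now rebuilt with mamba_ssm's MambaLMHeadModel from the checkpoint's saved model_args, replacing the previous mamba_lm construction (truncated in the diff above). A minimal sketch of that flow, assuming model_args holds the MambaConfig that MambaLMHeadModel expects; the checkpoint filename and device handling below are illustrative, not the module's exact code:

import torch
from mamba_ssm import MambaLMHeadModel

# Illustrative values; mamba_module.py derives ckpt_path from self.model_name.
ckpt_path = "../chess-mamba-vs-xformer/out/Mamba/model.pt"  # hypothetical filename
device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint = torch.load(ckpt_path, map_location=device)
model_config = checkpoint["model_args"]     # assumed: a mamba_ssm MambaConfig
model = MambaLMHeadModel(model_config)      # rebuild the architecture from the saved config
model.load_state_dict(checkpoint["model"])  # restore the trained weights
model.to(device)
model.eval()

Constructing the module from the stored config before load_state_dict keeps the instantiated architecture in sync with whatever shape was actually trained.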
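The generation-loop edit follows from that model swap: mamba_ssm's MambaLMHeadModel returns a namedtuple-style output from forward rather than a raw tensor, so the last-position logits are read via .logits instead of indexing the return value directly. A short sketch of one sampling step under that assumption; the function name and the nanoGPT-style top-k masking are illustrative rather than this module's exact code:

import torch
import torch.nn.functional as F

def sample_next_token(model, input_ids, temperature=0.7, top_k=200):
    # model(input_ids) returns an output object; .logits is (batch, seq, vocab).
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1, :]   # logits for the last token
        logits = logits / temperature                # temperature scaling
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[-1]] = -float("inf")   # mask below the k-th largest logit
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # sampled token id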