File size: 6,414 Bytes
f8519d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import pickle
from contextlib import nullcontext
from typing import Optional

import torch

from mamba_lm import MambaLM, MambaLMConfig, from_pretrained
# Directory prefix (relative path) under which the tokenizer metadata pickle
# is looked up as <BASE_DIR>/out/meta.pkl.
BASE_DIR = "mamba/"
class MambaPlayer:
    """Chess move generator backed by a character-level MambaLM checkpoint.

    The constructor loads the model and installs a character tokenizer;
    ``get_move`` then samples the next move for a game transcript.

    Attributes set by ``__init__``: ``model_name``, ``move_num_in_gamestate``,
    ``vocab_size``, ``encode``, ``decode``, ``stop_token_ids``, ``model``,
    ``ctx`` and ``device``.
    """

    # Characters that terminate a generated move once at least one token has
    # been emitted. The stop ids are derived from these characters instead of
    # the literal ids 0 and 4: in the built-in vocabulary id 4 is 'c', so
    # comparing against 4 truncated moves such as "Nc3" at the 'c'.
    # NOTE(review): for a vocabulary loaded from meta.pkl this assumes the
    # literal ids 0 and 4 were meant to be ' ' and '.' -- confirm against the
    # pickle used in training.
    _STOP_CHARS = (" ", ".")

    def __init__(self, model_name: str, move_num_in_gamestate: bool = False):
        """Load the model checkpoint and build the tokenizer.

        Args:
            model_name: Checkpoint file name, resolved against the relative
                path ``../../mamba.py/out/``.
            move_num_in_gamestate: When true and ``<BASE_DIR>/out/meta.pkl``
                exists, the vocabulary is read from that pickle; otherwise
                the built-in 28-symbol vocabulary is used.

        Raises:
            ValueError: If the hard-coded ``init_from`` setting is neither
                ``"resume"`` nor a ``state-spaces`` model id.
        """
        self.model_name = model_name
        self.move_num_in_gamestate = move_num_in_gamestate

        # ---------------------------------------------------------------
        # Hard-coded settings.
        init_from = "resume"  # 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Only query bf16 support when a CUDA device exists; the query can
        # fail on CPU-only hosts in some PyTorch versions.
        use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
        dtype = "bfloat16" if use_bf16 else "float32"
        seed = 1337
        compile_model = False  # set to True if using PyTorch 2.0 and Mamba supports it
        # ---------------------------------------------------------------

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        device_type = "cuda" if "cuda" in device else "cpu"
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
        # Autocast context kept on the instance for callers; the sampling
        # loop in get_mamba_response does not enter it.
        if device_type == "cpu":
            ctx = nullcontext()
        else:
            ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

        # Model initialization
        if init_from == "resume":
            ckpt_path = os.path.normpath(f"../../mamba.py/out/{self.model_name}")
            # NOTE(review): the checkpoint carries a config object under
            # 'model_args'; recent PyTorch releases default torch.load to
            # weights_only=True, which may reject it -- confirm against the
            # PyTorch version in use.
            checkpoint = torch.load(ckpt_path, map_location=device)
            model = MambaLM(checkpoint["model_args"])
            model.load_state_dict(checkpoint["model"])
        elif init_from.startswith("state-spaces"):
            model = from_pretrained(init_from).to(device)
        else:
            raise ValueError("Invalid init_from value")

        model.eval()
        model.to(device)
        if compile_model and hasattr(torch, "compile"):
            model = torch.compile(model)

        # Prefer the vocabulary stored next to the dataset when move numbers
        # are part of the game state and the pickle is available.
        meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
        if move_num_in_gamestate and os.path.exists(meta_path):
            # SECURITY: pickle.load can execute arbitrary code; meta.pkl must
            # come from a trusted source.
            with open(meta_path, "rb") as f:
                meta = pickle.load(f)
            self._set_tokenizer(
                meta["stoi"],
                meta["itos"],
                meta["vocab_size"],
                compact_castling=False,
            )
        else:
            stoi, itos = self._default_vocab()
            print(f"Vocab size {len(stoi)}")
            self._set_tokenizer(stoi, itos, len(stoi), compact_castling=True)

        self.model = model
        self.ctx = ctx
        self.device = device

    @staticmethod
    def _default_vocab():
        """Return the built-in ``(stoi, itos)`` character tables."""
        stoi = {
            ' ': 0, '.': 1,
            'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9,
            '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17,
            'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23,
            'x': 24, '+': 25, '#': 26, '=': 27,
        }
        # Deriving the inverse table keeps the two mappings consistent by
        # construction.
        itos = {index: char for char, index in stoi.items()}
        return stoi, itos

    def _set_tokenizer(self, stoi: dict, itos: dict, vocab_size: int, *, compact_castling: bool):
        """Install ``encode``/``decode``, ``vocab_size`` and ``stop_token_ids``.

        Args:
            stoi: Character -> token id table.
            itos: Token id -> character table.
            vocab_size: Vocabulary size reported to callers.
            compact_castling: When true, '-' is dropped before encoding and
                "OO"/"OOO" are expanded back to "O-O"/"O-O-O" on decoding
                (the built-in vocabulary has no '-' symbol).
        """
        if compact_castling:
            def encode(s):
                return [stoi[c] for c in s.replace('-', '')]

            def decode(l):
                text = "".join(itos[i] for i in l)
                return text.replace("OOO", "O-O-O").replace("OO", "O-O")
        else:
            def encode(s):
                return [stoi[c] for c in s]

            def decode(l):
                return "".join(itos[i] for i in l)

        self.vocab_size = vocab_size
        self.encode = encode
        self.decode = decode
        self.stop_token_ids = frozenset(
            stoi[c] for c in self._STOP_CHARS if c in stoi
        )

    def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
        """Sample a continuation of ``game_state`` and return it as text.

        Args:
            game_state: Transcript text; only the part after the last blank
                line ("\\n\\n") is used, stripped of surrounding whitespace.
            temperature: Softmax temperature. Values <= 0 select the most
                likely token at every step (greedy decoding).
            max_new_tokens: Upper bound on the number of generated tokens.
            top_k: If > 0, sample only among the ``top_k`` most likely
                tokens; values above the vocabulary size are clamped.

        Returns:
            The decoded generated text, cut at the first ';' if one occurs.
        """
        game_state = game_state.split("\n\n")[-1].strip()

        # Tokenize the game state
        encoded_prompt = self.encode(game_state)
        input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
        prompt_len = input_ids.size(1)

        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            # The first sampled token is always kept, even if it is a stop
            # token (typically the space separating the prompt from the new
            # move); any later stop token ends generation.
            first_token = True
            for _ in range(max_new_tokens):
                logits = self.model(input_ids)[0, -1, :]  # logits for the last position
                if temperature > 0:
                    logits = logits / temperature
                    if top_k > 0:
                        k = min(top_k, logits.size(-1))
                        kth_best = torch.topk(logits, k)[0][..., -1, None]
                        logits = logits.masked_fill(logits < kth_best, -float('Inf'))
                    probs = torch.nn.functional.softmax(logits, dim=-1)
                    next_token_id = torch.multinomial(probs, num_samples=1)
                else:
                    # Dividing by a non-positive temperature would yield
                    # NaN/inverted probabilities; fall back to argmax.
                    next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

                if not first_token and next_token_id.item() in self.stop_token_ids:
                    break
                first_token = False
                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

        # Decode only the generated tokens so the result does not depend on
        # the prompt surviving an encode/decode round-trip unchanged.
        model_response = self.decode(input_ids[0, prompt_len:].tolist())
        return model_response.split(";")[0]

    def get_move_from_response(self, response: str) -> Optional[str]:
        """Extract the first move from a generated response.

        Returns:
            The first whitespace-separated token with leading '.' characters
            removed, or ``None`` when the response is empty or contains only
            whitespace.
        """
        if not response:
            return None
        # Parse the response to get only the first move
        moves = response.split()
        if not moves:
            # Whitespace-only response: there is no move to return.
            return None
        # A patch for a weird phase during training ... doesn't seem to be an
        # issue anymore, but don't see the harm.
        return moves[0].lstrip('.')

    def get_move(self, board: str, game_state: str, temperature: float) -> Optional[str]:
        """Generate up to 8 tokens for ``game_state`` and return the first move.

        ``board`` is accepted for interface compatibility and is not used.
        """
        completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        """Return a dictionary identifying the loaded model."""
        return {"model": self.model_name}
|