File size: 18,164 Bytes
0955f14 629ef82 be25621 1dd6f7d 9ae57a2 0955f14 de4b222 0955f14 72c8584 f4a6bfa 1fa91dd 72c8584 1fa91dd de4b222 a73c8da de4b222 a73c8da de4b222 a73c8da de4b222 9ae57a2 f4a6bfa 0955f14 2649aea 0955f14 1dd6f7d 72c8584 0955f14 167b5e5 c7eb097 454586f 72c8584 f4a6bfa d90c994 4b17b1c d90c994 72c8584 d90c994 72c8584 4b17b1c d88f8fc 1fa91dd d0c6814 f4a6bfa 9082726 f4a6bfa 94c3d48 4b10b3e d90c994 4b10b3e f4a6bfa 4b10b3e de4b222 a73c8da de4b222 a73c8da de4b222 a73c8da abd7c69 a73c8da de4b222 9ae57a2 de4b222 a73c8da de4b222 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 |
import os
import pickle
from contextlib import nullcontext
from functools import partial

import chess
import numpy as np
import torch
from mamba_lm import MambaLMConfig, from_pretrained
from mamba_ssm import MambaLMHeadModel
from sklearn.linear_model import LinearRegression
from torch import nn, optim
# Base directory under which "out/meta.pkl" (the pickled vocabulary metadata) is looked up.
BASE_DIR = "mamba/"
class MambaPlayer:
    """Chess move generator backed by a character-level Mamba language model.

    The player loads a checkpoint, encodes a game-state string one character
    per token and samples the next move from the model.  Optionally it records
    per-layer activations through forward hooks (accumulated per move-number
    bucket and per "won"/"lost"/"current" category), can add a
    "won minus lost" activation offset back into the layers, and can train
    small linear probes on the recorded activations.
    """

    # Characters of the built-in vocabulary; a character's index is its token id.
    DEFAULT_VOCAB = " .abcdefgh12345678BNRQKOx+#="
    # Scalar targets for which one linear probe per layer is kept.
    PROBE_TYPES = ('q_value', 'q_value_delta', 'material_balance')
    # Accumulator slots kept per layer and per move bucket.
    ACTIVATION_CATEGORIES = ("won", "lost", "current")
    # Learning rate the probe optimizers are created with.
    DEFAULT_PROBE_LR = 0.01

    def __init__(self, model_name: str, move_num_in_gamestate: bool = False,
                 update_contrastive: bool = False, update_linear: bool = False,
                 linear_probe_path: str = None):
        """Load the model and vocabulary and set up optional activation tracking.

        Args:
            model_name: Checkpoint file name, resolved below
                ``../chess-mamba-vs-xformer/out/Mamba/``.
            move_num_in_gamestate: When true and ``meta.pkl`` exists, the
                vocabulary is read from that file instead of the built-in one.
            update_contrastive: Register hooks that accumulate layer outputs.
            update_linear: Additionally keep linear probes (loaded from
                ``linear_probe_path`` if it exists, freshly created otherwise)
                together with their optimizers and target buffers.
            linear_probe_path: Optional path of previously saved probes.

        Raises:
            ValueError: If ``init_from`` is neither ``"resume"`` nor a
                ``state-spaces`` model name.
        """
        self.model_name = model_name
        self.move_num_in_gamestate = move_num_in_gamestate

        # ---------------------------------------------------------------------
        init_from = "resume"  # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
        cuda_available = torch.cuda.is_available()
        device = "cuda" if cuda_available else "cpu"
        # Only query bf16 support when a CUDA device exists; the query itself
        # may touch the CUDA runtime.
        dtype = 'bfloat16' if cuda_available and torch.cuda.is_bf16_supported() else 'float32'
        seed = 1337
        compile_model = False  # set to True if using PyTorch 2.0 and Mamba supports it
        # ---------------------------------------------------------------------

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        device_type = "cuda" if "cuda" in device else "cpu"  # for torch.autocast
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
        # Stored as self.ctx; the generation code in this class does not enter it.
        ctx = (
            nullcontext()
            if device_type == "cpu"
            else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
        )

        # Model initialization
        if init_from == "resume":
            ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Mamba/{self.model_name}")
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(ckpt_path, map_location=device)
            model_config = checkpoint["model_args"]
            model = MambaLMHeadModel(model_config)
            model.load_state_dict(checkpoint['model'])
        elif init_from.startswith('state-spaces'):
            model = from_pretrained(init_from).to(device)
        else:
            raise ValueError("Invalid init_from value")
        model.eval()
        model.to(device)
        if compile_model and hasattr(torch, 'compile'):
            model = torch.compile(model)

        # look for the meta pickle in case it is available in the dataset folder
        meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
        load_meta = os.path.exists(meta_path)
        if move_num_in_gamestate and load_meta:
            # NOTE: pickle.load executes arbitrary code; meta.pkl must be trusted.
            with open(meta_path, "rb") as f:
                meta = pickle.load(f)
            stoi, itos = meta["stoi"], meta["itos"]
            vocab_size = meta['vocab_size']

            def encode(s):
                return [stoi[c] for c in s]

            def decode(l):
                return "".join([itos[i] for i in l])
        else:
            stoi = {ch: idx for idx, ch in enumerate(self.DEFAULT_VOCAB)}
            itos = {idx: ch for ch, idx in stoi.items()}
            vocab_size = len(stoi)
            print(f"Vocab size {vocab_size}")

            def encode(s):
                # Castling is tokenized without dashes ("O-O" -> "OO").
                return [stoi[c] for c in s.replace('-', '')]

            def decode(l):
                # Ids at or above vocab_size are dropped; the dashes removed
                # by encode() are restored for castling moves.
                text = "".join([itos[i] for i in l if i < vocab_size])
                return text.replace("OOO", "O-O-O").replace("OO", "O-O")

        self.vocab_size = vocab_size
        self.encode = encode
        self.decode = decode
        self.space_tok = encode(' ')[0]
        self.dot_tok = encode('.')[0]
        self.model = model
        self.ctx = ctx
        self.device = device
        self.move_num = 0
        self.hooks = []
        self.max_seq_len = 1536
        self.move_buckets = [10, 20, 30, 40, float('inf')]

        track_activations = update_contrastive or update_linear
        have_saved_probes = bool(linear_probe_path) and os.path.exists(linear_probe_path)

        if track_activations:
            self.activations_sum = {}
            self.activations_count = {}
        if update_linear:
            # NOTE: torch.load unpickles the file; only load trusted probe files.
            self.linear_probes = torch.load(linear_probe_path) if have_saved_probes else {}

        if track_activations:
            for i, layer in enumerate(self.model.backbone.layers):
                self.activations_sum[i], self.activations_count[i] = self._zeroed_activation_stats()
                self.hooks.append(
                    layer.register_forward_hook(partial(self._record_activations, layer_idx=i))
                )
                if update_linear and not have_saved_probes:
                    self.linear_probes[i] = {
                        probe_type: nn.Linear(self.model.config.d_model, 1)
                        for probe_type in self.PROBE_TYPES
                    }

        if update_linear:
            # Built after the loop so freshly created probes get optimizers too.
            self.linear_optimizers = {
                layer_idx: {
                    probe_type: optim.Adam(probes[probe_type].parameters(), lr=self.DEFAULT_PROBE_LR)
                    for probe_type in self.PROBE_TYPES
                }
                for layer_idx, probes in self.linear_probes.items()
            }
            self.linear_probe_targets = self._empty_probe_targets()

    def _current_bucket(self):
        """Return the first move bucket whose upper bound covers ``self.move_num``."""
        return next(b for b in self.move_buckets if self.move_num <= b)

    def _zeroed_activation_stats(self):
        """Return fresh ``(sums, counts)`` dicts for one layer, all set to zero."""
        shape = (1, self.max_seq_len, self.model.config.d_model)
        sums = {
            bucket: {category: np.zeros(shape) for category in self.ACTIVATION_CATEGORIES}
            for bucket in self.move_buckets
        }
        counts = {
            bucket: {category: 0 for category in self.ACTIVATION_CATEGORIES}
            for bucket in self.move_buckets
        }
        return sums, counts

    def _empty_probe_targets(self):
        """Return empty per-layer, per-bucket target lists for every probe type."""
        return {
            layer_idx: {
                bucket: {probe_type: [] for probe_type in self.PROBE_TYPES}
                for bucket in self.move_buckets
            }
            for layer_idx in self.linear_probes
        }

    def _record_activations(self, module, inputs, output, layer_idx):
        """Forward hook: add a layer's output to the "current" accumulator."""
        hidden = output[0] if isinstance(output, tuple) else output
        # Positions beyond the accumulator length are ignored instead of
        # raising a shape error.
        seq_len = min(hidden.shape[1], self.max_seq_len)
        bucket = self._current_bucket()
        # Cast to float32 first: numpy cannot represent bfloat16.
        values = hidden[:, :seq_len, :].detach().float().cpu().numpy()
        self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += values
        self.activations_count[layer_idx][bucket]["current"] += 1

    def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
        """Sample a continuation of ``game_state`` and return the generated text.

        Only the text after the last blank-line-separated section of
        ``game_state`` is used as the prompt.  Leading space/dot tokens are
        kept; generation stops at the first space or dot that follows a
        non-space token, or after ``max_new_tokens`` tokens.

        Returns:
            The generated text (up to the first ``";"``), or ``None`` when
            sampling fails (e.g. non-finite probabilities).
        """
        game_state = game_state.split("\n\n")[-1].strip()
        encoded_prompt = self.encode(game_state)
        input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
        stop_tokens = (self.space_tok, self.dot_tok)

        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            have_non_space = False
            for _ in range(max_new_tokens):
                logits = self.model(input_ids).logits[0, -1, :]  # logits for the last token
                logits = logits / temperature
                if top_k > 0:
                    # topk raises if k exceeds the number of logits, so clamp it.
                    k = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k)[0][..., -1, None]
                    logits[logits < kth_value] = -float('Inf')
                probs = torch.nn.functional.softmax(logits, dim=-1)
                probs = torch.clamp(probs, min=1e-6, max=1.0)
                probs = probs / probs.sum()
                try:
                    next_token_id = torch.multinomial(probs, num_samples=1)
                except RuntimeError:
                    # Raised for invalid (nan/inf/negative) probabilities.
                    return None
                if next_token_id.item() in stop_tokens:
                    if have_non_space:
                        break
                else:
                    have_non_space = True
                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

        model_response = self.decode(input_ids[0].tolist())
        model_response = model_response[len(game_state):].split(";")[0]
        return model_response

    def get_move_from_response(self, response: str) -> str:
        """Return the first whitespace-separated move of ``response``.

        Leading dots are stripped from the move.  ``None`` is returned when
        ``response`` is not a string, is empty, or holds only whitespace.
        """
        if not isinstance(response, str) or not response:
            return None
        moves = response.split()
        if not moves:
            return None
        # A patch for a weird phase during training; harmless otherwise.
        return moves[0].lstrip('.')

    def get_move(self, board: "chess.Board", game_state: str, temperature: float) -> str:
        """Generate the next move for ``game_state``.

        The move number used for activation bucketing is taken from the number
        of dots in ``game_state``.  ``board`` is not used by this method.
        """
        self.move_num = game_state.count('.')
        completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        """Return a dict identifying the loaded model."""
        return {"model": self.model_name}

    def update_activations(self, result):
        """Fold the "current" accumulators into ``result``, or reset everything.

        Args:
            result: ``"reset"`` zeroes all accumulators and counters; any other
                value must be an existing category key (e.g. ``"won"`` or
                ``"lost"``) and receives the "current" sums, with its counter
                incremented by one per bucket.

        Raises:
            KeyError: If ``result`` is neither ``"reset"`` nor a category key.
        """
        for layer_idx in self.activations_sum:
            if result == "reset":
                self.activations_sum[layer_idx], self.activations_count[layer_idx] = (
                    self._zeroed_activation_stats()
                )
                continue
            for bucket in self.move_buckets:
                sums = self.activations_sum[layer_idx][bucket]
                sums[result] += sums["current"]
                self.activations_count[layer_idx][bucket][result] += 1

    def save_activations(self, path):
        """Merge the "won"/"lost" statistics into the pickle at ``path``.

        Existing contents of ``path`` are loaded and added to.  Buckets whose
        "current" counter is zero are skipped.  After writing, all in-memory
        accumulators and counters are reset to zero.
        """
        if os.path.exists(path):
            # NOTE: pickle.load executes arbitrary code; ``path`` must be trusted.
            with open(path, "rb") as f:
                activations_sum, activations_count = pickle.load(f)
        else:
            activations_sum = {}
            activations_count = {}

        shape = (1, self.max_seq_len, self.model.config.d_model)
        for layer_idx in self.activations_sum:
            for bucket in self.move_buckets:
                if self.activations_count[layer_idx][bucket]["current"] == 0:
                    continue
                if layer_idx not in activations_sum:
                    activations_sum[layer_idx] = {}
                    activations_count[layer_idx] = {}
                if bucket not in activations_sum[layer_idx]:
                    activations_sum[layer_idx][bucket] = {}
                    activations_count[layer_idx][bucket] = {}
                saved_sums = activations_sum[layer_idx][bucket]
                saved_counts = activations_count[layer_idx][bucket]
                for category in ("won", "lost"):
                    if category not in saved_sums:
                        saved_sums[category] = np.zeros(shape)
                        saved_counts[category] = 0
                    saved_sums[category] += self.activations_sum[layer_idx][bucket][category]
                    saved_counts[category] += self.activations_count[layer_idx][bucket][category]

        with open(path, "wb") as f:
            pickle.dump((activations_sum, activations_count), f)

        for layer_idx in self.activations_sum:
            self.activations_sum[layer_idx], self.activations_count[layer_idx] = (
                self._zeroed_activation_stats()
            )

    def apply_contrastive_activations(self, path, weight=1.0):
        """Register hooks that add ``weight * (mean won - mean lost)`` to layers.

        The statistics are read from the pickle written by
        ``save_activations``.  Nothing happens when ``path`` does not exist.
        Non-finite entries of the offset (e.g. from a zero count) are replaced
        by zero, and buckets absent from the file leave the output untouched.
        """
        if not os.path.exists(path):
            return
        # NOTE: pickle.load executes arbitrary code; ``path`` must be trusted.
        with open(path, "rb") as f:
            activations_sum, activations_count = pickle.load(f)
        self.contrastive_activations_cache = {}

        def hook(module, inputs, output, layer_idx):
            hidden = output[0] if isinstance(output, tuple) else output
            bucket = self._current_bucket()
            layer_cache = self.contrastive_activations_cache.setdefault(layer_idx, {})
            offset = layer_cache.get(bucket)
            if offset is None:
                bucket_sums = activations_sum[layer_idx].get(bucket)
                if bucket_sums is None:
                    # No statistics were saved for this bucket; returning None
                    # keeps the layer output as it is.
                    return None
                bucket_counts = activations_count[layer_idx][bucket]
                # Zero counts yield inf/nan here; they are masked out below.
                with np.errstate(divide="ignore", invalid="ignore"):
                    won_mean = bucket_sums["won"] / bucket_counts["won"]
                    lost_mean = bucket_sums["lost"] / bucket_counts["lost"]
                diff = torch.from_numpy(won_mean - lost_mean).to(
                    device=hidden.device, dtype=hidden.dtype
                )
                offset = torch.where(torch.isfinite(diff), diff, torch.zeros_like(diff))
                layer_cache[bucket] = offset
            # Only positions covered by both tensors are modified (in place).
            n = min(hidden.shape[1], offset.shape[1])
            hidden[:, :n, :] += offset[:, :n, :] * weight
            if isinstance(output, tuple):
                return (hidden,) + tuple(output[1:])
            return hidden

        for layer_idx in activations_sum:
            layer = self.model.backbone.layers[layer_idx]
            self.hooks.append(layer.register_forward_hook(partial(hook, layer_idx=layer_idx)))

    def update_linear_probe_targets(self, curr_q_value, q_value_delta, material_bal):
        """Append one target per probe type to the current bucket of every layer."""
        bucket = self._current_bucket()
        new_targets = {
            'q_value': curr_q_value,
            'q_value_delta': q_value_delta,
            'material_balance': material_bal,
        }
        for layer_idx in self.linear_probe_targets:
            for probe_type, value in new_targets.items():
                self.linear_probe_targets[layer_idx][bucket][probe_type].append(value)

    def train_linear_probes(self, lr=0.01):
        """Run one optimizer step per probe on the accumulated activations.

        For every layer and bucket with recorded activations, each probe is
        applied to the summed "current" activations and fitted with MSE loss
        against the collected targets.  All target buffers are cleared
        afterwards.

        Args:
            lr: Learning rate applied to the probe optimizers for this call.
        """
        criterion = nn.MSELoss()
        for layer_idx, probes in self.linear_probes.items():
            for bucket in self.move_buckets:
                if self.activations_count[layer_idx][bucket]['current'] <= 0:
                    continue
                # The summed (not averaged) activations are used as input.
                features = torch.from_numpy(
                    self.activations_sum[layer_idx][bucket]['current']
                ).float()
                for probe_type in self.PROBE_TYPES:
                    targets = self.linear_probe_targets[layer_idx][bucket][probe_type]
                    if len(targets) == 0:
                        continue
                    probe = probes[probe_type]
                    optimizer = self.linear_optimizers[layer_idx][probe_type]
                    for group in optimizer.param_groups:
                        group['lr'] = lr
                    device = probe.weight.device
                    y = torch.tensor(targets).float().view(-1, 1, 1).to(device)
                    # NOTE(review): the probe yields one prediction per sequence
                    # position while targets are one scalar per recorded move.
                    # Every position is compared with every target; with a
                    # single target this equals plain MSELoss broadcasting.
                    # Confirm this pairing matches the intended probe design.
                    y_pred, y = torch.broadcast_tensors(probe(features.to(device)), y)
                    loss = criterion(y_pred, y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
        # Reset linear_probe_targets after training
        self.linear_probe_targets = self._empty_probe_targets()

    def save_linear_probe_data(self, path):
        """Serialize the linear probes to ``path`` with ``torch.save``."""
        torch.save(self.linear_probes, path)

    def evaluate_linear_probes(self, board: "chess.Board", game_state: str):
        """Print each probe's prediction for the current bucket's activations.

        The move number is taken from the number of dots in ``game_state``.
        ``board`` is not used by this method.
        """
        self.move_num = game_state.count('.')
        bucket = self._current_bucket()
        for layer_idx, probes in self.linear_probes.items():
            features = torch.from_numpy(
                self.activations_sum[layer_idx][bucket]['current']
            ).float()
            for probe_type in self.PROBE_TYPES:
                probe = probes[probe_type]
                with torch.no_grad():
                    prediction = probe(features.to(probe.weight.device))
                # NOTE(review): there is one prediction per sequence position;
                # the mean is reported so a single scalar can be printed.
                print(f"Layer {layer_idx}, {probe_type}: {prediction.mean().item()}")