import os
import pickle
import torch
from mamba_lm import MambaLMConfig, from_pretrained
from mamba_ssm import MambaLMHeadModel
from contextlib import nullcontext
import numpy as np
from functools import partial
import chess
from sklearn.linear_model import LinearRegression
import torch.nn as nn
import torch.optim as optim
import wandb
import math
import json
BASE_DIR = "mamba/"
class MambaPlayer:
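    """Chess player backed by a Mamba language model.

    Moves are sampled from PGN-style game-state strings. Optionally collects
    per-layer activations for contrastive activation steering and for training
    linear probes (q-value, q-value delta, material balance).
    """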
def __init__(self, model_name: str, move_num_in_gamestate: bool=False, update_contrastive: bool=False, update_linear: bool=False, linear_probe_path: str=None):
self.model_name = model_name
self.move_num_in_gamestate = move_num_in_gamestate
# -----------------------------------------------------------------------------
init_from = "resume" # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
out_dir = "out" # ignored if init_from is not 'resume'
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"
dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float32'
seed = 1337
compile = False # set to True if using PyTorch 2.0 and Mamba supports it
# -----------------------------------------------------------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device_type = (
"cuda" if "cuda" in device else "cpu"
) # for later use in torch.autocast
ptdtype = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}[dtype]
ctx = (
nullcontext()
if device_type == "cpu"
else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
)
# Model initialization
if init_from == "resume":
#ckpt_path = os.path.join(BASE_DIR, out_dir, self.model_name)
ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Mamba/{self.model_name}")
checkpoint = torch.load(ckpt_path, map_location=device)
model_config = checkpoint["model_args"]
model = MambaLMHeadModel(model_config)
model.load_state_dict(checkpoint['model'])
elif init_from.startswith('state-spaces'):
model = from_pretrained(init_from).to(device)
else:
raise ValueError("Invalid init_from value")
model.eval()
model.to(device)
if compile and hasattr(torch, 'compile'):
model = torch.compile(model)
# look for the meta pickle in case it is available in the dataset folder
meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
load_meta = os.path.exists(meta_path)
if move_num_in_gamestate and load_meta:
with open(meta_path, "rb") as f:
meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
vocab_size = meta['vocab_size']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])
else:
stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
itos = {0: ' ', 1: '.', 2: 'a', 3: 'b', 4: 'c', 5: 'd', 6: 'e', 7: 'f', 8: 'g', 9: 'h', 10: '1', 11: '2', 12: '3', 13: '4', 14: '5', 15: '6', 16: '7', 17: '8', 18: 'B', 19: 'N', 20: 'R', 21: 'Q', 22: 'K', 23: 'O', 24: 'x', 25: '+', 26: '#', 27: '='}
for s in stoi:
assert itos[stoi[s]] == s
vocab_size = len(stoi)
print(f"Vocab size {vocab_size}")
encode = lambda s: [stoi[c] for c in s.replace('-', '')]
decode = lambda l: "".join([itos[i] for i in l if i < vocab_size]).replace("OOO", "O-O-O").replace("OO", "O-O")
self.vocab_size = vocab_size
self.encode = encode
self.decode = decode
self.space_tok = encode(' ')[0]
self.dot_tok = encode('.')[0]
self.model = model
self.ctx = ctx
self.device = device
self.move_num = 0
self.hooks = []
self.max_seq_len = 1536
#self.move_buckets = [10, 20, 30, 40, float('inf')]
self.move_buckets = [float('inf')]
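        # A single catch-all bucket; restore the commented-out list above to
        # split activation statistics by move number.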
if update_contrastive or update_linear:
self.activations_sum = {}
self.activations_count = {}
if update_linear:
if linear_probe_path and os.path.exists(linear_probe_path):
self.linear_probes = torch.load(linear_probe_path)
else:
self.linear_probes = {}
if update_contrastive or update_linear:
            linear_size = self.model.config.d_model * 8  # probes read the last 8 positions, flattened
for i, layer in enumerate(self.model.backbone.layers):
self.activations_sum[i] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
"lost": np.zeros((1, 8, self.model.config.d_model)),
"current": np.zeros((1, 8, self.model.config.d_model))}
for bucket in self.move_buckets}
self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
for bucket in self.move_buckets}
                def hook(module, input, output, layer_idx=i):
                    # Some Mamba blocks return (hidden_states, residual); keep the hidden states.
                    if isinstance(output, tuple):
                        tensor_output = output[0]
                    else:
                        tensor_output = output
                    seq_len = tensor_output.shape[1]
                    bucket = next(b for b in self.move_buckets if self.move_num <= b)
                    # Accumulate the last (up to) 8 positions, left-aligned in the (1, 8, d_model) buffer.
                    n = min(8, seq_len)
                    self.activations_sum[layer_idx][bucket]["current"][:, :n, :] += tensor_output.detach().cpu().numpy()[:, -n:, :]
                    self.activations_count[layer_idx][bucket]["current"] += 1
self.hooks.append(layer.register_forward_hook(hook))
if update_linear:
if not linear_probe_path or not os.path.exists(linear_probe_path):
self.linear_probes[i] = {
'q_value': nn.Linear(linear_size, 1),
'q_value_delta': nn.Linear(linear_size, 1),
'material_balance': nn.Linear(linear_size, 1)
}
if update_linear:
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
self.linear_optimizers = {
layer_idx: {
probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=0.01)
for probe_type in ['q_value', 'q_value_delta', 'material_balance']
}
for layer_idx in self.linear_probes
}
            wandb.init(project="mamba_linear_probes", name="mamba_linear_probes")
self.wandb_step = 0
self.linear_save_ct = 0
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
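        """Sample a continuation of `game_state` from the model.

        Tokens are drawn one at a time with temperature and top-k filtering;
        generation stops at the first space/dot that follows non-space output.
        Returns the decoded continuation, or None if sampling fails.
        """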
game_state = game_state.split("\n\n")[-1].strip()
#game_state = ";" + game_state
# Tokenize the game state
encoded_prompt = self.encode(game_state)
input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
self.seq_len = input_ids[0].size(dim=0)
self.model.eval() # Set the model to evaluation mode
with torch.no_grad():
have_non_space = False
for _ in range(max_new_tokens):
logits = self.model(input_ids).logits[0, -1, :] # Get logits for the last token
# Apply temperature scaling and optionally sample from top k tokens
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = -float('Inf')
probs = torch.nn.functional.softmax(logits, dim=-1)
probs = torch.clamp(probs, min=1e-6, max=1.0)
probs = probs / probs.sum()
                try:
                    next_token_id = torch.multinomial(probs, num_samples=1)
                except RuntimeError:
                    # multinomial can fail if probs still contain invalid values
                    return None
                if next_token_id == self.space_tok or next_token_id == self.dot_tok:
                    if have_non_space:
                        break
                else:
                    have_non_space = True
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
self.seq_len += 1
model_response = self.decode(input_ids[0].tolist())
model_response = model_response[len(game_state):].split(";")[0]
return model_response
def get_move_from_response(self, response: str) -> str:
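        """Extract the first move token from a model completion, or None."""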
        if not response:
            return None
# Parse the response to get only the first move
try:
moves = response.split()
first_move = moves[0]
first_move = first_move.lstrip('.') # A patch for a weird phase during training ... doesn't seem to be an issue anymore, but don't see the harm.
return first_move
        except IndexError:
            # response contained only whitespace
            return None
def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
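        """Return the model's next move (in SAN) for the given game state."""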
self.move_num = game_state.count('.')
completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
return self.get_move_from_response(completion)
def get_config(self) -> dict:
return {"model": self.model_name}
def update_activations(self, result):
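        """Fold the current game's activations into the "won" or "lost" totals,
        or zero all accumulators when result == "reset".
        """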
for layer_idx in self.activations_sum:
if result == "reset":
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
"lost": np.zeros((1, 8, self.model.config.d_model)),
"current": np.zeros((1, 8, self.model.config.d_model))}
for bucket in self.move_buckets}
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
for bucket in self.move_buckets}
else:
for bucket in self.move_buckets:
self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
self.activations_count[layer_idx][bucket][result] += 1
def save_activations(self, path):
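        """Merge this instance's won/lost activation sums into the pickle at
        `path`, then reset the in-memory accumulators.
        """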
if os.path.exists(path):
with open(path, "rb") as f:
activations_sum, activations_count = pickle.load(f)
else:
activations_sum = {}
activations_count = {}
for layer_idx in self.activations_sum:
for bucket in self.move_buckets:
if self.activations_count[layer_idx][bucket]["current"] == 0:
continue
if layer_idx not in activations_sum:
activations_sum[layer_idx] = {}
activations_count[layer_idx] = {}
if bucket not in activations_sum[layer_idx]:
activations_sum[layer_idx][bucket] = {}
activations_count[layer_idx][bucket] = {}
for category in ["won", "lost"]:
if category not in activations_sum[layer_idx][bucket]:
activations_sum[layer_idx][bucket][category] = np.zeros((1, 8, self.model.config.d_model))
activations_count[layer_idx][bucket][category] = 0
activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
activations_count[layer_idx][bucket][category] += self.activations_count[layer_idx][bucket][category]
with open(path, "wb") as f:
pickle.dump((activations_sum, activations_count), f)
for layer_idx in self.activations_sum:
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
"lost": np.zeros((1, 8, self.model.config.d_model)),
"current": np.zeros((1, 8, self.model.config.d_model))}
for bucket in self.move_buckets}
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
for bucket in self.move_buckets}
def apply_contrastive_activations(self, path, weight=1.0):
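        """Register forward hooks that steer each layer's output by adding
        weight * (mean "won" activations - mean "lost" activations).
        """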
        if not os.path.exists(path):
            return
        with open(path, "rb") as f:
            activations_sum, activations_count = pickle.load(f)
self.contrastive_activations_cache = {}
def hook(module, input, output, layer_idx):
if isinstance(output, tuple):
tensor_output = output[0]
else:
tensor_output = output
seq_len = tensor_output.shape[1]
bucket = next(b for b in self.move_buckets if self.move_num <= b)
# Check cache first
if layer_idx in self.contrastive_activations_cache and bucket in self.contrastive_activations_cache[layer_idx]:
safe_contrastive_activations = self.contrastive_activations_cache[layer_idx][bucket]
else:
won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
contrastive_activations = won_activations - lost_activations
                # Match the hidden-state dtype so the in-place add below works under autocast.
                contrastive_activations_tensor = torch.from_numpy(contrastive_activations).to(device=tensor_output.device, dtype=tensor_output.dtype)
valid_activations = torch.isfinite(contrastive_activations_tensor)
safe_contrastive_activations = torch.zeros_like(contrastive_activations_tensor)
safe_contrastive_activations[valid_activations] = contrastive_activations_tensor[valid_activations]
# Cache the safe activations
if layer_idx not in self.contrastive_activations_cache:
self.contrastive_activations_cache[layer_idx] = {}
self.contrastive_activations_cache[layer_idx][bucket] = safe_contrastive_activations
            # Apply the steering vector to the last positions, mirroring how it was collected.
            n = min(8, seq_len)
            tensor_output[:, -n:, :] += safe_contrastive_activations[:, :n, :] * weight
if isinstance(output, tuple):
return tensor_output, output[1]
else:
return tensor_output
for layer_idx in activations_sum:
self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
))
def update_linear_probe_targets(self, curr_q_value, q_value_delta, material_bal):
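        """Queue this move's regression targets (shared by every layer's probes)."""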
bucket = next(b for b in self.move_buckets if self.move_num <= b)
for layer_idx in self.linear_probe_targets:
self.linear_probe_targets[layer_idx][bucket]['q_value'].append(curr_q_value)
self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
def train_linear_probes(self):
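        """Run one Adam step per probe on the activations accumulated since the
        last call, with a cosine learning-rate decay shared across probes.
        """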
def get_lr(it):
warmup_iters = 0 #300 * 43
lr_decay_iters = 3000 * 43
learning_rate = 0.0000075
min_lr = 0.00000075
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
criterion = nn.MSELoss()
self.wandb_step += 1
lr = get_lr(self.wandb_step)
for layer_idx in self.linear_probes:
for bucket in self.move_buckets:
if self.activations_count[layer_idx][bucket]['current'] > 0:
                    X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)  # summed (not mean) activations
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
if len(y) > 0:
y_pred = self.linear_probes[layer_idx][probe_type](X)
loss = criterion(y_pred, y)
for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
param_group['lr'] = lr
self.linear_optimizers[layer_idx][probe_type].zero_grad()
loss.backward()
self.linear_optimizers[layer_idx][probe_type].step()
#wandb.log({f"{probe_type}/layer_{layer_idx}_{bucket}_loss": loss.item()})
wandb.log({
"etc/lr": lr,
f"{probe_type}/layer_{layer_idx}_loss": loss.item()
}, step=self.wandb_step)
# Reset linear_probe_targets after training
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
def save_linear_probe_data(self, path):
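        """Checkpoint the probes; the counter assumes this runs once per 25 games."""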
self.linear_save_ct += 25
wandb.log({
"etc/games": self.linear_save_ct
}, step=self.wandb_step)
torch.save(self.linear_probes, path)
def evaluate_linear_probes(self, board: chess.Board):
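        """Score each probe's prediction against this move's queued target and
        append per-move accuracies to probe_stats.json.

        Assumes exactly one target per probe has been queued since the last
        reset (the .item() call below requires a single element).
        """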
self.move_num = board.fullmove_number
bucket = next(b for b in self.move_buckets if self.move_num <= b)
# Create a dictionary to store the statistics for the current move
probe_stats = {probe_type: {layer_idx: {self.move_num: None} for layer_idx in self.linear_probes} for probe_type in ['q_value', 'q_value_delta', 'material_balance']}
for layer_idx in self.linear_probes:
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
probe = self.linear_probes[layer_idx][probe_type]
prediction = probe(X).item()
#print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
# Calculate the percentage accuracy based on the probe type
if probe_type == 'q_value':
accuracy = 1 - abs(prediction - target) / 2 # Q-value range: -1 to 1
elif probe_type == 'q_value_delta':
accuracy = 1 - abs(prediction - target) / 4 # Q-value delta range: -2 to 2
else: # material_balance
max_range = 35 # Adjust this value based on the expected range of material balance
accuracy = 1 - min(abs(prediction - target) / max_range, 1)
# Store the accuracy in the probe_stats dictionary for the current move
probe_stats[probe_type][layer_idx][self.move_num] = accuracy
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
# Append the probe_stats to the file
with open('probe_stats.json', 'a') as f:
json.dump(probe_stats, f)
f.write('\n') # Add a newline separator between moves
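

# A minimal usage sketch, not part of the class: plays a short self-play game.
# Assumptions (adjust for your setup): a checkpoint file named "ckpt.pt" exists
# under ../chess-mamba-vs-xformer/out/Mamba/, and the game-state string follows
# the "1.e4 e5 2.Nf3 ..." movetext format the model was trained on; the
# temperature value is arbitrary.
if __name__ == "__main__":
    player = MambaPlayer("ckpt.pt")
    board = chess.Board()
    game_state = "1."
    for _ in range(20):
        move_san = player.get_move(board, game_state, temperature=0.7)
        if move_san is None:
            break
        try:
            board.push_san(move_san)
        except ValueError:
            break  # the model produced an illegal or unparsable move
        if board.turn == chess.BLACK:
            game_state += move_san + " "
        else:
            game_state += f"{move_san} {board.fullmove_number}."
    print(game_state)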