File size: 18,164 Bytes
0955f14 629ef82 be25621 1dd6f7d 9ae57a2 0955f14 de4b222 0955f14 72c8584 f4a6bfa 1fa91dd 72c8584 1fa91dd de4b222 a73c8da de4b222 a73c8da de4b222 a73c8da de4b222 9ae57a2 f4a6bfa 0955f14 2649aea 0955f14 1dd6f7d 72c8584 0955f14 167b5e5 c7eb097 454586f 72c8584 f4a6bfa d90c994 4b17b1c d90c994 72c8584 d90c994 72c8584 4b17b1c d88f8fc 1fa91dd d0c6814 f4a6bfa 9082726 f4a6bfa 94c3d48 4b10b3e d90c994 4b10b3e f4a6bfa 4b10b3e de4b222 a73c8da de4b222 a73c8da de4b222 a73c8da abd7c69 a73c8da de4b222 9ae57a2 de4b222 a73c8da de4b222 |
|
import os
import pickle
import torch
from mamba_lm import MambaLMConfig, from_pretrained
from mamba_ssm import MambaLMHeadModel
from contextlib import nullcontext
import numpy as np
from functools import partial
import chess
from sklearn.linear_model import LinearRegression
# Root directory for this player's assets; MambaPlayer.__init__ looks for the
# optional vocabulary pickle at <BASE_DIR>/out/meta.pkl.
BASE_DIR = "mamba/"
class MambaPlayer:
    """Chess move generator backed by a Mamba language model.

    The player loads a checkpoint, tokenizes a PGN-like game string character
    by character, and samples the next move from the model. Optionally it
    registers forward hooks on every backbone layer to:

    * accumulate per-layer activations, bucketed by move number, for games
      that were later "won" or "lost" (contrastive activations), and/or
    * train small linear probes on those activations.

    Attributes set by ``__init__`` that other code reads:
        model_name, vocab_size, encode, decode, space_tok, dot_tok, model,
        ctx, device, move_num, hooks, max_seq_len, move_buckets, and (only
        when activation tracking is enabled) activations_sum /
        activations_count / linear_probes / linear_optimizers /
        linear_probe_targets.
    """

    # Scalar targets predicted by the per-layer linear probes.
    _PROBE_TYPES = ('q_value', 'q_value_delta', 'material_balance')
    # Learning rate used when the probe optimizers are first created; it is
    # overwritten by the ``lr`` argument of ``train_linear_probes``.
    _DEFAULT_PROBE_LR = 0.01

    def __init__(self, model_name: str, move_num_in_gamestate: bool=False, update_contrastive: bool=False, update_linear: bool=False, linear_probe_path: str=None):
        """Load the model, build the tokenizer and optionally attach hooks.

        Args:
            model_name: Checkpoint file name, resolved relative to
                ``../chess-mamba-vs-xformer/out/Mamba/``.
            move_num_in_gamestate: When True and ``<BASE_DIR>/out/meta.pkl``
                exists, the vocabulary is read from that pickle; otherwise the
                built-in 28-symbol vocabulary is used.
            update_contrastive: Record layer activations for later
                won/lost aggregation.
            update_linear: Record layer activations and create (or load)
                linear probes plus their optimizers.
            linear_probe_path: Optional file previously written by
                ``save_linear_probe_data``; loaded when it exists.

        Raises:
            ValueError: If the internal ``init_from`` setting is invalid.
        """
        self.model_name = model_name
        self.move_num_in_gamestate = move_num_in_gamestate

        # -----------------------------------------------------------------------------
        init_from = "resume"  # either 'resume' or a Mamba variant (e.g. 'state-spaces/mamba-1.4b')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Only ask about bf16 support when a CUDA device is actually present.
        dtype = 'bfloat16' if device == "cuda" and torch.cuda.is_bf16_supported() else 'float32'
        seed = 1337
        compile_model = False  # set to True if using PyTorch 2.0 and Mamba supports it
        # -----------------------------------------------------------------------------

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
        ptdtype = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
        }[dtype]
        # NOTE(review): ctx is stored on the instance but no method in this
        # class enters it; generation currently runs outside autocast.
        if device_type == "cpu":
            ctx = nullcontext()
        else:
            ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

        # Model initialization
        if init_from == "resume":
            ckpt_path = os.path.normpath(f"../chess-mamba-vs-xformer/out/Mamba/{self.model_name}")
            checkpoint = torch.load(ckpt_path, map_location=device)
            model_config = checkpoint["model_args"]
            model = MambaLMHeadModel(model_config)
            model.load_state_dict(checkpoint['model'])
        elif init_from.startswith('state-spaces'):
            model = from_pretrained(init_from).to(device)
        else:
            raise ValueError("Invalid init_from value")

        model.eval()
        model.to(device)
        if compile_model and hasattr(torch, 'compile'):
            model = torch.compile(model)

        # Look for the meta pickle in case it is available in the dataset folder.
        # NOTE: pickle.load executes arbitrary code; only use trusted files.
        meta_path = os.path.join(BASE_DIR, "out", "meta.pkl")
        if move_num_in_gamestate and os.path.exists(meta_path):
            with open(meta_path, "rb") as f:
                meta = pickle.load(f)
            stoi, itos = meta["stoi"], meta["itos"]
            vocab_size = meta['vocab_size']

            def encode(s):
                return [stoi[c] for c in s]

            def decode(ids):
                return "".join([itos[i] for i in ids])
        else:
            stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
            # Derive the inverse table so the two mappings cannot drift apart.
            itos = {idx: ch for ch, idx in stoi.items()}
            vocab_size = len(stoi)
            print(f"Vocab size {vocab_size}")

            def encode(s):
                # The vocabulary has no '-', so castling is stored as OO / OOO.
                return [stoi[c] for c in s.replace('-', '')]

            def decode(ids):
                # Ids outside the vocabulary are dropped, then the castling
                # dashes removed by encode() are restored.
                text = "".join([itos[i] for i in ids if i < vocab_size])
                return text.replace("OOO", "O-O-O").replace("OO", "O-O")

        self.vocab_size = vocab_size
        self.encode = encode
        self.decode = decode
        self.space_tok = encode(' ')[0]
        self.dot_tok = encode('.')[0]
        self.model = model
        self.ctx = ctx
        self.device = device
        self.move_num = 0
        self.hooks = []
        self.max_seq_len = 1536
        # Upper bounds (inclusive) of the move-number buckets; the final
        # infinity guarantees every move number falls in some bucket.
        self.move_buckets = [10, 20, 30, 40, float('inf')]

        if update_contrastive or update_linear:
            self.activations_sum = {}
            self.activations_count = {}
            for i, layer in enumerate(self.model.backbone.layers):
                self.activations_sum[i] = self._new_activation_sums()
                self.activations_count[i] = self._new_activation_counts()
                self.hooks.append(
                    layer.register_forward_hook(partial(self._record_activations, layer_idx=i))
                )

        if update_linear:
            if linear_probe_path and os.path.exists(linear_probe_path):
                self.linear_probes = torch.load(linear_probe_path)
            else:
                d_model = self.model.config.d_model
                self.linear_probes = {
                    i: {probe_type: torch.nn.Linear(d_model, 1) for probe_type in self._PROBE_TYPES}
                    for i, _ in enumerate(self.model.backbone.layers)
                }
            # Optimizers must be created after the probes exist, otherwise
            # freshly initialised probes would have no optimizer at all.
            self.linear_optimizers = {
                layer_idx: {
                    probe_type: torch.optim.Adam(
                        self.linear_probes[layer_idx][probe_type].parameters(),
                        lr=self._DEFAULT_PROBE_LR,
                    )
                    for probe_type in self._PROBE_TYPES
                }
                for layer_idx in self.linear_probes
            }
            self.linear_probe_targets = self._new_probe_targets()

    def _current_bucket(self):
        """Return the first move bucket whose bound is >= ``self.move_num``."""
        return next(b for b in self.move_buckets if self.move_num <= b)

    def _new_activation_sums(self):
        """Return zeroed won/lost/current accumulators for every move bucket."""
        shape = (1, self.max_seq_len, self.model.config.d_model)
        return {
            bucket: {category: np.zeros(shape) for category in ("won", "lost", "current")}
            for bucket in self.move_buckets
        }

    def _new_activation_counts(self):
        """Return zeroed won/lost/current counters for every move bucket."""
        return {
            bucket: {"won": 0, "lost": 0, "current": 0}
            for bucket in self.move_buckets
        }

    def _new_probe_targets(self):
        """Return empty per-layer, per-bucket target lists for every probe type."""
        return {
            layer_idx: {
                bucket: {probe_type: [] for probe_type in self._PROBE_TYPES}
                for bucket in self.move_buckets
            }
            for layer_idx in self.linear_probes
        }

    def _record_activations(self, module, inputs, output, layer_idx):
        """Forward hook: add this layer's output to the "current" accumulator."""
        tensor_output = output[0] if isinstance(output, tuple) else output
        # Clamp so prompts longer than the buffer do not break broadcasting.
        seq_len = min(tensor_output.shape[1], self.max_seq_len)
        bucket = self._current_bucket()
        # .float() first: NumPy cannot represent bfloat16 tensors.
        activations = tensor_output.detach()[:, :seq_len, :].float().cpu().numpy()
        self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += activations
        self.activations_count[layer_idx][bucket]["current"] += 1

    def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
        """Sample a continuation of ``game_state`` from the model.

        Only the text after the last blank line of ``game_state`` is used as
        the prompt. Sampling stops after ``max_new_tokens`` tokens, or as soon
        as a space/dot token is drawn after at least one other token.

        Args:
            game_state: Game transcript to continue.
            temperature: Divisor applied to the logits before sampling.
            max_new_tokens: Maximum number of tokens to generate.
            top_k: If > 0, logits below the k-th largest are masked out.

        Returns:
            The generated text (up to the first ';'), or None when sampling
            fails (e.g. the probabilities are not finite).
        """
        game_state = game_state.split("\n\n")[-1].strip()
        # Tokenize the game state
        encoded_prompt = self.encode(game_state)
        input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)

        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            have_non_space = False
            for _ in range(max_new_tokens):
                logits = self.model(input_ids).logits[0, -1, :]  # Logits for the last token
                # Apply temperature scaling and optionally keep only the top k tokens
                logits = logits / temperature
                if top_k > 0:
                    # torch.topk raises when k exceeds the number of logits.
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k)[0][..., -1, None]
                    logits[logits < kth_best] = -float('Inf')
                probs = torch.nn.functional.softmax(logits, dim=-1)
                probs = torch.clamp(probs, min=1e-6, max=1.0)
                probs = probs / probs.sum()
                try:
                    next_token_id = torch.multinomial(probs, num_samples=1)
                except RuntimeError:
                    # Invalid (nan/inf) distribution; the caller treats None as "no move".
                    return None
                if next_token_id.item() in (self.space_tok, self.dot_tok):
                    # A separator after real content ends the move; leading
                    # separators are kept and generation continues.
                    if have_non_space:
                        break
                else:
                    have_non_space = True
                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

        model_response = self.decode(input_ids[0].tolist())
        return model_response[len(game_state):].split(";")[0]

    def get_move_from_response(self, response: str) -> str:
        """Return the first whitespace-delimited move in ``response``.

        Leading dots are stripped from the move. Returns None when
        ``response`` is None, empty, or contains only whitespace.
        """
        if not response:
            return None
        moves = response.split()
        if not moves:
            return None
        # A patch for a weird phase during training ... doesn't seem to be an
        # issue anymore, but don't see the harm.
        return moves[0].lstrip('.')

    def get_move(self, board: "chess.Board", game_state: str, temperature: float) -> str:
        """Generate the next move for ``game_state`` (``board`` is not used).

        Also updates ``self.move_num`` to the number of '.' characters in
        ``game_state``, which selects the activation bucket used by the hooks.
        """
        self.move_num = game_state.count('.')
        completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        """Return a dict identifying this player's model."""
        return {"model": self.model_name}

    def update_activations(self, result):
        """Credit the "current" activations to an outcome, or reset everything.

        Args:
            result: ``"reset"`` clears all accumulators and counters; any
                other value is used as the outcome key (``"won"`` or
                ``"lost"``) that receives the "current" sums.
        """
        for layer_idx in self.activations_sum:
            if result == "reset":
                self.activations_sum[layer_idx] = self._new_activation_sums()
                self.activations_count[layer_idx] = self._new_activation_counts()
            else:
                # NOTE(review): "current" is not cleared here, so it keeps
                # accumulating across calls until a reset or save.
                for bucket in self.move_buckets:
                    self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
                    self.activations_count[layer_idx][bucket][result] += 1

    def save_activations(self, path):
        """Merge won/lost accumulators into the pickle at ``path`` and reset.

        Existing data at ``path`` is loaded first and added to. Buckets whose
        "current" count is zero are skipped. All in-memory accumulators are
        zeroed afterwards.
        """
        if os.path.exists(path):
            # NOTE: pickle.load executes arbitrary code; only use trusted files.
            with open(path, "rb") as f:
                activations_sum, activations_count = pickle.load(f)
        else:
            activations_sum, activations_count = {}, {}

        shape = (1, self.max_seq_len, self.model.config.d_model)
        for layer_idx in self.activations_sum:
            for bucket in self.move_buckets:
                if self.activations_count[layer_idx][bucket]["current"] == 0:
                    continue
                bucket_sums = activations_sum.setdefault(layer_idx, {}).setdefault(bucket, {})
                bucket_counts = activations_count.setdefault(layer_idx, {}).setdefault(bucket, {})
                for category in ("won", "lost"):
                    if category not in bucket_sums:
                        bucket_sums[category] = np.zeros(shape)
                        bucket_counts[category] = 0
                    bucket_sums[category] += self.activations_sum[layer_idx][bucket][category]
                    bucket_counts[category] += self.activations_count[layer_idx][bucket][category]

        with open(path, "wb") as f:
            pickle.dump((activations_sum, activations_count), f)

        for layer_idx in self.activations_sum:
            self.activations_sum[layer_idx] = self._new_activation_sums()
            self.activations_count[layer_idx] = self._new_activation_counts()

    def apply_contrastive_activations(self, path, weight=1.0):
        """Register hooks that add (mean won - mean lost) activations to layers.

        Args:
            path: Pickle written by ``save_activations``. If the file does
                not exist, nothing is registered.
            weight: Scale applied to the contrastive activations.
        """
        if not os.path.exists(path):
            return
        # NOTE: pickle.load executes arbitrary code; only use trusted files.
        with open(path, "rb") as f:
            activations_sum, activations_count = pickle.load(f)

        self.contrastive_activations_cache = {}

        def hook(module, inputs, output, layer_idx):
            tensor_output = output[0] if isinstance(output, tuple) else output
            bucket = self._current_bucket()

            layer_cache = self.contrastive_activations_cache.setdefault(layer_idx, {})
            steering = layer_cache.get(bucket)
            if steering is None:
                bucket_sums = activations_sum[layer_idx].get(bucket)
                if bucket_sums is None:
                    # No data was saved for this bucket; leave output untouched.
                    return None
                bucket_counts = activations_count[layer_idx][bucket]
                # A zero count yields inf/nan, which is zeroed out below.
                with np.errstate(divide="ignore", invalid="ignore"):
                    won_mean = bucket_sums["won"] / bucket_counts["won"]
                    lost_mean = bucket_sums["lost"] / bucket_counts["lost"]
                contrast = torch.from_numpy(won_mean - lost_mean).to(tensor_output.device)
                finite = torch.isfinite(contrast)
                steering = torch.zeros_like(contrast)
                steering[finite] = contrast[finite]
                layer_cache[bucket] = steering

            # Only the positions covered by both tensors can be steered.
            n = min(tensor_output.shape[1], steering.shape[1])
            tensor_output[:, :n, :] += steering[:, :n, :] * weight

            if isinstance(output, tuple):
                # Preserve every trailing element of the original tuple.
                return (tensor_output,) + tuple(output[1:])
            return tensor_output

        for layer_idx in activations_sum:
            self.hooks.append(
                self.model.backbone.layers[layer_idx].register_forward_hook(
                    partial(hook, layer_idx=layer_idx)
                )
            )

    def update_linear_probe_targets(self, curr_q_value, q_value_delta, material_bal):
        """Append one target value per probe type to the current move bucket."""
        bucket = self._current_bucket()
        for layer_idx in self.linear_probe_targets:
            targets = self.linear_probe_targets[layer_idx][bucket]
            targets['q_value'].append(curr_q_value)
            targets['q_value_delta'].append(q_value_delta)
            targets['material_balance'].append(material_bal)

    def train_linear_probes(self, lr=0.01):
        """Run one optimizer step per probe on the recorded activations.

        Args:
            lr: Learning rate applied to every probe optimizer for this call.

        The collected probe targets are cleared afterwards.
        """
        criterion = torch.nn.MSELoss()
        for layer_idx in self.linear_probes:
            for probe_type in self._PROBE_TYPES:
                for group in self.linear_optimizers[layer_idx][probe_type].param_groups:
                    group['lr'] = lr
            for bucket in self.move_buckets:
                if self.activations_count[layer_idx][bucket]['current'] <= 0:
                    continue
                # Un-normalised sum; dividing by the "current" count was
                # deliberately disabled in the original code.
                X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()
                for probe_type in self._PROBE_TYPES:
                    targets = self.linear_probe_targets[layer_idx][bucket][probe_type]
                    if not targets:
                        continue
                    y = torch.tensor(targets).float().unsqueeze(1)
                    # NOTE(review): the probe output has one value per sequence
                    # position while y has one row per recorded target; confirm
                    # these shapes are meant to broadcast.
                    y_pred = self.linear_probes[layer_idx][probe_type](X)
                    loss = criterion(y_pred, y)
                    optimizer = self.linear_optimizers[layer_idx][probe_type]
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

        # Reset linear_probe_targets after training
        self.linear_probe_targets = self._new_probe_targets()

    def save_linear_probe_data(self, path):
        """Serialize the linear probes to ``path`` with ``torch.save``."""
        torch.save(self.linear_probes, path)

    def evaluate_linear_probes(self, board: "chess.Board", game_state: str):
        """Print each probe's prediction for the current bucket (``board`` unused)."""
        self.move_num = game_state.count('.')
        bucket = self._current_bucket()
        for layer_idx in self.linear_probes:
            # The accumulator is a NumPy array, so convert rather than torch.cat.
            X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()
            for probe_type in self._PROBE_TYPES:
                probe = self.linear_probes[layer_idx][probe_type]
                with torch.no_grad():
                    prediction = probe(X)
                # NOTE(review): .item() needs a single element; when the probe
                # emits one value per position the mean is reported instead.
                # Confirm this is the intended reduction.
                value = prediction.item() if prediction.numel() == 1 else prediction.mean().item()
                print(f"Layer {layer_idx}, {probe_type}: {value}")