"""Train a Mamba or transformer ("xformer") language model on tokenized chess games.

The script is configured through the module-level defaults below, which can be
overridden by ``configurator.py`` (command line or config file).  It supports
single-process and DDP training, resuming from a checkpoint, and a final
"anneal" phase in which the learning rate is decayed linearly.
"""
import glob
import math
import os
import pickle
import random
import time
from contextlib import nullcontext

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader

# -----------------------------------------------------------------------------
# default config values designed for Mamba model training
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 5
eval_only = False
always_save_checkpoint = True
init_from = 'resume'  # 'scratch', 'resume', 'anneal', or Mamba model name
# wandb logging
wandb_log = False
wandb_project = 'mamba'
wandb_run_name = 'mamba_run'  # modify as needed
# data
dataset = 'chess'  # specify your dataset
gradient_accumulation_steps = 5 * 8
batch_size = 12
base_batch_size = batch_size
effective_batch_size = batch_size
max_seq_len = 1024  # For xformer, this is the block size
train_file_update_interval = 7
# model
# TODO: add 'xformer' type / model parameters. Move model imports to after
# exec() (when these values are finalized).
model_type = 'mamba'
n_layer = 12
d_model = 768
dt_rank = 'auto'
d_state = 16
expand_factor = 2
bias = False
conv_bias = True
pscan = True
vocab_size = 32
move_num_in_gamestate = True
# xformer-specific params. Note that n_layer, vocab_size,
# move_num_in_gamestate, and bias are shared by both model types.
n_head = 12
n_embd = 768
dropout = 0.0  # for pretraining 0 is good, for finetuning try 0.1+
# optimizer settings
learning_rate = 6e-4
max_iters = 600000  # max_iters is for auto-stopping end of stable phase
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 0.5
auto_clip = False
auto_clip_max = 0.5
auto_clip_min = 3.333e-3
grad_clip_start_size = 100
grad_clip_max_size = 500
grad_clip_percentile = 10
# learning rate decay settings
decay_lr = True
warmup_iters = 2000
min_lr = 6e-5
# DDP settings
backend = 'nccl'
# system
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Only query bf16 support when CUDA is actually present.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float32'
compile = False  # set to True if using PyTorch 2.0
# -----------------------------------------------------------------------------
# Snapshot the names of the simple-valued settings defined above, then let the
# configurator override them, then capture the final values for logging.
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
with open('configurator.py') as _configurator_file:
    exec(_configurator_file.read())  # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# -----------------------------------------------------------------------------

anneal_checkpoint = 'anneal/ckpt.pt'
anneal_dir = os.path.join(out_dir, 'anneal/')
anneal_start_iters = None  # Set at init
anneal_decay_iters = None  # Set at init

if model_type == 'mamba':
    from mamba_lm import MambaLM, MambaLMConfig
    from mamba_ssm import MambaLMHeadModel
    model_config = MambaLMConfig(
        d_model=d_model,
        #n_layers=n_layer,
        n_layer=n_layer,
        ssm_cfg={
            'dt_rank': dt_rank,
            'd_state': d_state,
            #'expand_factor': expand_factor,
            'bias': bias,
            'conv_bias': conv_bias,
            #'pscan':pscan,
        },
        vocab_size=vocab_size,
        pad_vocab_size_multiple=1,
    ).to_mamba_config()
elif model_type == 'xformer':
    from xformer import GPTConfig, GPT
    model_config = GPTConfig(
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        block_size=max_seq_len,
        bias=bias,
        vocab_size=vocab_size,
        dropout=dropout,
    )
else:
    raise ValueError(f"Unknown model_type {model_type}.")

# DDP and other initializations
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    # Each rank performs an equal share of the accumulation steps.
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1

if master_process:
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(anneal_dir, exist_ok=True)

torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
# 'float16' is included because the GradScaler below is enabled for it.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join('data', dataset)
current_train_file_index = 0
train_files = (
    glob.glob(os.path.join(data_dir, 'train*.parquet'))
    + glob.glob(os.path.join(data_dir, 'stable*.parquet'))
    + glob.glob(os.path.join(data_dir, 'anneal*.parquet'))
)
if not train_files:
    raise FileNotFoundError(
        f"No train*/stable*/anneal* parquet files found in {data_dir}"
    )

train_datasets = []
print("Loading dataset...")
for train_file in train_files:
    # Use a dedicated name so the `dataset` config string is not clobbered.
    df = pq.read_table(train_file).to_pandas()
    df = df[df['tokenized'].apply(len) >= 8]  # drop games shorter than 8 tokens
    train_datasets.append(df)
    print('.', end='', flush=True)
print("\nLoaded.")
#val_data = pq.read_table(os.path.join(data_dir, 'val.parquet')).to_pandas()
#val_data = val_data[val_data['tokenized'].apply(len) >= 8]

truncated_games_count = 0
total_games_count = 0
games_seen = 0
tokens_seen = 0
tokens_seen_padded = 0


def get_batch(split):
    """Sample ``batch_size`` games and return them as a zero-padded batch.

    Games are drawn from the currently selected training file.  Each game is
    truncated to ``max_seq_len`` tokens and right-padded with zeros to the
    longest (truncated) game in the batch.  The global token/game counters
    are updated, and the active training file is re-drawn at random whenever
    the number of batches served is a multiple of
    ``train_file_update_interval``.

    Args:
        split: Must be ``'train'``; validation data loading is disabled.

    Returns:
        A tuple ``(sequences, max_length_in_batch)`` where ``sequences`` is an
        int64 tensor of shape ``(batch_size, max_length_in_batch)`` on
        ``device``.

    Raises:
        ValueError: If ``split`` is not ``'train'``.
    """
    global truncated_games_count, total_games_count, current_train_file_index, tokens_seen, tokens_seen_padded

    if split != 'train':
        raise ValueError(f"Unsupported split {split!r}: only 'train' data is loaded.")

    # Randomly select batch_size games
    source_df = train_datasets[current_train_file_index]
    games = source_df.sample(batch_size)['tokenized'].tolist()

    # Prepare sequences tensor for the batch
    max_length_in_batch = min(max(len(game) for game in games), max_seq_len)
    pad_to = max_length_in_batch  #if model_type == 'mamba' else max_seq_len
    sequences = torch.zeros((batch_size, pad_to), dtype=torch.int64)
    for i, game in enumerate(games):
        total_games_count += 1
        game_len = min(len(game), pad_to)
        tokens_seen += game_len
        tokens_seen_padded += pad_to
        sequences[i, :game_len] = torch.tensor(game[:game_len], dtype=torch.int64)

    if (total_games_count // batch_size) % train_file_update_interval == 0:
        current_train_file_index = random.randint(0, len(train_files) - 1)
        # print(f"Switched to file: {train_files[current_train_file_index]}")

    if device_type == 'cuda':
        # Pinned memory allows the host-to-device copy to be asynchronous.
        sequences = sequences.pin_memory().to(device, non_blocking=True)
    else:
        sequences = sequences.to(device)
    return sequences, max_length_in_batch


# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if not move_num_in_gamestate:
    meta_vocab_size = 28
elif os.path.exists(meta_path):
    # NOTE: pickle is only safe on trusted files; meta.pkl is expected to be
    # produced by this project's own data preparation.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# Model initialization
if init_from == 'scratch':
    print(f"Initializing a new {model_type} model from scratch")
    if meta_vocab_size is None:
        print(f"defaulting to vocab_size of {vocab_size}")
    else:
        model_config.vocab_size = meta_vocab_size
    if model_type == 'mamba':
        #model = MambaLM(model_config)
        model = MambaLMHeadModel(model_config)
    else:
        model = GPT(model_config)
    if auto_clip:
        # Start unclipped; the threshold is derived from history later.
        grad_clip = 0
        config['grad_clip'] = 0
    grad_norm_history = []
elif init_from == 'resume' or init_from == 'anneal':
    print(f"Resuming training from {out_dir}")
    if init_from == 'anneal':
        ckpt_path = os.path.join(out_dir, anneal_checkpoint)
    else:
        ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    model_config = checkpoint['model_args']
    if model_type == 'mamba':
        #model = MambaLM(model_config)
        model = MambaLMHeadModel(model_config)
    else:
        model = GPT(model_config)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

    if 'effective_batch_size' not in checkpoint['config']:
        print("Checkpoint was saved without `effective_batch_size`, assuming current value (will save with next checkpoint). This is used for correcting `iter_num` when the effetive batch size is changed.")
        checkpoint['config']['effective_batch_size'] = effective_batch_size
    # Rescale iter_num so that it stays proportional to games processed when
    # the effective batch size differs from the checkpoint's.
    iter_num = int(round(checkpoint['iter_num'] * (checkpoint['config']['effective_batch_size'] / effective_batch_size)))
    if 'games_seen' in checkpoint:
        games_seen = checkpoint['games_seen']
    else:
        games_seen = checkpoint['config']['effective_batch_size'] * checkpoint['iter_num']
        checkpoint['games_seen'] = games_seen
        print(f"Checkpoint was saved without `games_seen`, assuming checkpoint's effective batch size * iters (will save with next checkpoint). {games_seen}")
    tokens_seen = checkpoint.get('tokens_seen', 0)
    tokens_seen_padded = checkpoint.get('tokens_seen_padded', 0)
    best_val_loss = checkpoint['best_val_loss']
    print(f"Best val loss: {best_val_loss}")
    if auto_clip:
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
    #grad_norm_history = [t.item() if torch.is_tensor(t) else t for t in checkpoint.get('grad_norm_history', [])]
    grad_norm_history = checkpoint.get('grad_norm_history', [])

    if init_from == 'anneal':
        print(f"\n\nANNEAL STARTING/RESUMING FROM ITERNUM: {iter_num} ({games_seen} games)\n\n")
        anneal_start_iters = iter_num if 'anneal_start_iters' not in checkpoint else checkpoint['anneal_start_iters']
        # / 9 is og, but going deeper on lr too (can always take earlier ckpt
        # during anneal if it doesn't keep improving)... have used 6.75
        anneal_decay_iters = iter_num / 8 if 'anneal_decay_iters' not in checkpoint else checkpoint['anneal_decay_iters']
        print(anneal_start_iters)
        print(anneal_decay_iters)
        if 'anneal_start_iters' not in checkpoint:
            # Fresh anneal (not a resumed one): reset clipping state.
            grad_clip = 0
            config['grad_clip'] = 0
            grad_norm_history = []
        print(f"Starting anneal. Resumed from {anneal_checkpoint}, will now decay learning rate for {anneal_decay_iters} / until iter_num {anneal_start_iters + anneal_decay_iters}.")
        out_dir = anneal_dir
        weight_decay = weight_decay / 12.5  # / 17.0
        beta2 = np.sqrt(beta2) * beta2
        auto_clip = True
        grad_clip_percentile = 6.75
elif init_from.startswith('state-spaces'):
    print(f"Initializing from Mamba pre-trained weights: {init_from}")
    if model_type != 'mamba':
        raise ValueError("Pre-trained 'state-spaces' weights require model_type == 'mamba'")
    model = MambaLMHeadModel.from_pretrained(init_from)
    model_config = model.config
    grad_norm_history = []
else:
    raise ValueError("Invalid init_from value")

model.to(device)
print(f'Model with {sum(p.numel() for p in model.parameters())} parameters loaded.')

# Optimizer and GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None  # drop the reference to the loaded checkpoint

# Compile the model if using PyTorch 2.0
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

# Wrap model in DDP container if necessary
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])


def batch_to_loss(sequences, max_length_in_batch):
    """Compute the next-token prediction loss for a batch.

    Args:
        sequences: Token tensor as returned by ``get_batch``; the input is
            every token but the last and the target is every token but the
            first.
        max_length_in_batch: Unused; accepted so the result of ``get_batch``
            can be unpacked directly into this function.

    Returns:
        The scalar loss tensor.
    """
    inputs = sequences[:, :-1]  # exclude last token for input
    targets = sequences[:, 1:].reshape(-1)  # shifted by one for next-token prediction
    if model_type == 'mamba':
        logits = model(inputs).logits
        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)
    _, loss = model(inputs, targets)
    return loss


@torch.no_grad()
def estimate_loss():
    """Estimate the loss over ``eval_iters`` freshly sampled training batches.

    The model is switched to eval mode for the duration and back to train
    mode afterwards.  The token counters touched by ``get_batch`` are
    restored so evaluation does not count towards tokens seen.

    Returns:
        A dict with the single key ``'val'`` holding the mean loss (computed
        on training batches, since no validation data is loaded).
    """
    global tokens_seen, tokens_seen_padded
    out = {}
    model.eval()
    tokens_seen_b4 = tokens_seen
    tokens_seen_padded_b4 = tokens_seen_padded
    for split in ['train']:  #['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            loss = batch_to_loss(*get_batch(split))
            losses[k] = loss.item()
        split = 'val'  # Temporary hack: report the train estimate as 'val'
        out[split] = losses.mean()
    tokens_seen = tokens_seen_b4
    tokens_seen_padded = tokens_seen_padded_b4
    model.train()
    return out


# WSD scheduler
def get_lr(it):
    """Return the learning rate for iteration ``it``.

    During an anneal run the rate decays linearly from ``learning_rate`` to
    ``min_lr`` over ``anneal_decay_iters`` iterations starting at
    ``anneal_start_iters``.  Otherwise it warms up linearly for
    ``warmup_iters`` iterations and then stays at ``learning_rate``.
    """
    if init_from == 'anneal':
        # Linear decay from max LR to min LR over anneal_decay_iters iters
        decay_ratio = min(it - anneal_start_iters, anneal_decay_iters) / anneal_decay_iters
        return learning_rate - decay_ratio * (learning_rate - min_lr)
    if it < warmup_iters:
        # Warmup
        return learning_rate * it / warmup_iters
    # Stable max LR
    return learning_rate


# Logging setup
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# Training loop
local_iter_num = 0  # Number of iterations in the lifetime of this process
last_crossed_multiple = 0
save_every_n_games = 150000
raw_model = model.module if ddp else model  # Unwrap DDP container if needed


def _build_checkpoint(best_loss):
    """Assemble a checkpoint dict from the current global training state.

    Args:
        best_loss: Value to store under ``'best_val_loss'``.

    Returns:
        A dict suitable for ``torch.save``.  The anneal bookkeeping keys are
        included only for anneal runs.
    """
    ckpt = {
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': model_config,
        'iter_num': iter_num,
        'games_seen': games_seen,
        'tokens_seen': tokens_seen,
        'tokens_seen_padded': tokens_seen_padded,
        'best_val_loss': best_loss,
        'config': config,
        'grad_norm_history': grad_norm_history,
    }
    if init_from == 'anneal':
        ckpt['anneal_start_iters'] = anneal_start_iters
        ckpt['anneal_decay_iters'] = anneal_decay_iters
    return ckpt


# initial save (all counters are still zero at this point); only the master
# process writes so DDP ranks do not race on the same file
if init_from == 'scratch' and master_process:
    checkpoint = _build_checkpoint(best_val_loss)
    print(f"saving checkpoint to {out_dir}\n")
    torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

t0 = time.time()
while True:
    # Determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
        torch.cuda.empty_cache()
        losses = estimate_loss()
        if init_from == 'anneal':
            print(f"\ngame {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): 'val' loss {losses['val']:.4f}")
        else:
            print(f"\ngame {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): 'val' loss {losses['val']:.4f}")

        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            grad_clip_prev = grad_clip
            grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
            grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min)
            # Transition between grad_clips smoothly, weighed to new value
            grad_clip = (grad_clip * 9.0 + grad_clip_prev * 4.0) / 13.0
            grad_clip = max(min(grad_clip, auto_clip_max), auto_clip_min)  # should never actually clip here
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        torch.cuda.empty_cache()

        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/games": games_seen,
                "etc/tokens_seen": tokens_seen,
                "etc/tokens_seen_padded": tokens_seen_padded,
                "etc/grad_clip": grad_clip,
                "etc/lr": lr,
                "val/loss": losses['val'],
            })

        if losses['val'] < best_val_loss or always_save_checkpoint:
            if iter_num > 0:
                checkpoint = _build_checkpoint(min(best_val_loss, losses['val']))
                print(f"saving checkpoint to {out_dir}\n")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
                # Additionally keep a snapshot on every new best loss, or
                # whenever another save_every_n_games boundary is crossed.
                current_nearest_multiple = (games_seen // save_every_n_games) * save_every_n_games
                if losses['val'] < best_val_loss:  # Temporary / only good after it's settled
                    best_val_loss = losses['val']
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
                elif current_nearest_multiple != last_crossed_multiple:  # elif so we don't double up
                    last_crossed_multiple = current_nearest_multiple
                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}.pt'))

    # NOTE(review): the evaluation above requires local_iter_num > 0, so this
    # exits before any evaluation has run — confirm that is intended.
    if iter_num == 0 and eval_only:
        break

    # Forward and backward pass
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # Only synchronize gradients across ranks on the last micro step.
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        sequences, max_length_in_batch = get_batch('train')  # Fetch the training data
        with ctx:
            loss = batch_to_loss(sequences, max_length_in_batch)
            loss = loss / gradient_accumulation_steps
        scaler.scale(loss).backward()
        #print('.', end='')

    # clip the gradient
    if grad_clip != 0.0 or auto_clip:
        scaler.unscale_(optimizer)
        # The 0 check is for auto_clip enabled but not enough history: a very
        # large threshold still lets us record the gradient norm.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip if grad_clip != 0.0 else 999.9)
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]

    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)
    #torch.cuda.empty_cache()

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if init_from == 'anneal':
            print(f"game {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms")
        else:
            print(f"game {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms")
        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/games": games_seen,
                "etc/tokens_seen": tokens_seen,
                "etc/tokens_seen_padded": tokens_seen_padded,
                "etc/grad_norm": grad_norm_history[-1] if grad_norm_history else 0,
                "etc/lr": lr,
                "train/loss": lossf,
            })

    iter_num += 1
    local_iter_num += 1
    games_seen += effective_batch_size

    # termination conditions
    if iter_num > max_iters and not init_from == 'anneal':
        # max iters is for auto-stopping end of stable phase
        if master_process:
            checkpoint = _build_checkpoint(best_val_loss)
            print(f"Max_iters reached. Saving pre-anneal checkpoint to {anneal_checkpoint}")
            torch.save(checkpoint, os.path.join(out_dir, anneal_checkpoint))
        break
    if init_from == 'anneal' and iter_num >= anneal_start_iters + anneal_decay_iters:
        if master_process:
            checkpoint = _build_checkpoint(best_val_loss)
            print(f"Anneal complete. Saving checkpoint to {out_dir}")
            torch.save(checkpoint, os.path.join(out_dir, 'anneal_complete.pt'))
        break

if ddp:
    destroy_process_group()