HaileyStorm
commited on
Update chess-mamba-vs-xformer/train_bygame.py
Browse files
chess-mamba-vs-xformer/train_bygame.py
CHANGED
@@ -394,6 +394,9 @@ if init_from == 'scratch':
|
|
394 |
print(f"saving checkpoint to {out_dir}\n")
|
395 |
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
396 |
|
|
|
|
|
|
|
397 |
t0 = time.time()
|
398 |
while True:
|
399 |
# Determine and set the learning rate for this iteration
|
@@ -402,7 +405,7 @@ while True:
|
|
402 |
param_group['lr'] = lr
|
403 |
|
404 |
# Evaluate the loss on train/val sets and write checkpoints
|
405 |
-
if master_process and ((iter_num % eval_interval == 0 and local_iter_num > 0) or
|
406 |
torch.cuda.empty_cache()
|
407 |
losses = estimate_loss()
|
408 |
if init_from == 'anneal':
|
@@ -453,7 +456,7 @@ while True:
|
|
453 |
if losses['val'] < best_val_loss: # Temporary / only good after it's settled
|
454 |
best_val_loss = losses['val']
|
455 |
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
|
456 |
-
elif current_nearest_multiple != last_crossed_multiple or
|
457 |
last_crossed_multiple = current_nearest_multiple
|
458 |
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}g_{tokens_seen_padded}t.pt'))
|
459 |
|
|
|
394 |
print(f"saving checkpoint to {out_dir}\n")
|
395 |
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
|
396 |
|
397 |
+
GAMES_SEEN_CHECKPOINTS = [12652800, 22275000, 11536000, 16250000, 18000000, 19690000, 22005050]
|
398 |
+
TOKENS_SEEN_PADDED_CHECKPOINTS = [7798839804]
|
399 |
+
|
400 |
t0 = time.time()
|
401 |
while True:
|
402 |
# Determine and set the learning rate for this iteration
|
|
|
405 |
param_group['lr'] = lr
|
406 |
|
407 |
# Evaluate the loss on train/val sets and write checkpoints
|
408 |
+
if master_process and ((iter_num % eval_interval == 0 and local_iter_num > 0) or any(abs(games_seen - checkpoint) <= 151 for checkpoint in GAMES_SEEN_CHECKPOINTS) or any(abs(tokens_seen_padded - checkpoint) <= 46238 for checkpoint in TOKENS_SEEN_PADDED_CHECKPOINTS)):
|
409 |
torch.cuda.empty_cache()
|
410 |
losses = estimate_loss()
|
411 |
if init_from == 'anneal':
|
|
|
456 |
if losses['val'] < best_val_loss: # Temporary / only good after it's settled
|
457 |
best_val_loss = losses['val']
|
458 |
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
|
459 |
+
elif current_nearest_multiple != last_crossed_multiple or any(abs(games_seen - checkpoint) <= 151 for checkpoint in GAMES_SEEN_CHECKPOINTS) or any(abs(tokens_seen_padded - checkpoint) <= 46238 for checkpoint in TOKENS_SEEN_PADDED_CHECKPOINTS): # elif so we don't double up
|
460 |
last_crossed_multiple = current_nearest_multiple
|
461 |
torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}g_{tokens_seen_padded}t.pt'))
|
462 |
|