HaileyStorm commited on
Commit
ce494c6
·
verified ·
1 Parent(s): 26c20cd

Update chess-mamba-vs-xformer/train_bygame.py

Browse files
chess-mamba-vs-xformer/train_bygame.py CHANGED
@@ -402,7 +402,7 @@ while True:
402
  param_group['lr'] = lr
403
 
404
  # Evaluate the loss on train/val sets and write checkpoints
405
- if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
406
  torch.cuda.empty_cache()
407
  losses = estimate_loss()
408
  if init_from == 'anneal':
@@ -453,7 +453,7 @@ while True:
453
  if losses['val'] < best_val_loss: # Temporary / only good after it's settled
454
  best_val_loss = losses['val']
455
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
456
- elif current_nearest_multiple != last_crossed_multiple or abs(games_seen - 12652800) <= 151 or abs(games_seen - 22275000) <= 151 or abs(games_seen - 11510000) <= 151: # elif so we don't double up
457
  last_crossed_multiple = current_nearest_multiple
458
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}g_{tokens_seen_padded}t.pt'))
459
 
 
402
  param_group['lr'] = lr
403
 
404
  # Evaluate the loss on train/val sets and write checkpoints
405
+ if (iter_num % eval_interval == 0 and master_process and local_iter_num > 0) or abs(games_seen - 12652800) <= 151 or abs(games_seen - 22275000) <= 151 or abs(games_seen - 11536000) <= 151:
406
  torch.cuda.empty_cache()
407
  losses = estimate_loss()
408
  if init_from == 'anneal':
 
453
  if losses['val'] < best_val_loss: # Temporary / only good after it's settled
454
  best_val_loss = losses['val']
455
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
456
+ elif current_nearest_multiple != last_crossed_multiple or abs(games_seen - 12652800) <= 151 or abs(games_seen - 22275000) <= 151 or abs(games_seen - 11536000) <= 151: # elif so we don't double up
457
  last_crossed_multiple = current_nearest_multiple
458
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}g_{tokens_seen_padded}t.pt'))
459