HaileyStorm
committed on
Update chess-mamba-vs-xformer/train_bygame.py
Browse files
chess-mamba-vs-xformer/train_bygame.py
CHANGED
@@ -394,7 +394,10 @@ while True:
|
|
394 |
if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
|
395 |
torch.cuda.empty_cache()
|
396 |
losses = estimate_loss()
|
397 |
-
|
|
|
|
|
|
|
398 |
if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
|
399 |
grad_clip_prev = grad_clip
|
400 |
grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
|
@@ -481,7 +484,10 @@ while True:
|
|
481 |
# get loss as float. note: this is a CPU-GPU sync point
|
482 |
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
|
483 |
lossf = loss.item() * gradient_accumulation_steps
|
484 |
-
|
|
|
|
|
|
|
485 |
if wandb_log:
|
486 |
wandb.log({
|
487 |
"etc/iter": iter_num,
|
|
|
394 |
if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
|
395 |
torch.cuda.empty_cache()
|
396 |
losses = estimate_loss()
|
397 |
+
if init_from == 'anneal':
|
398 |
+
print(f"\ngame {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): 'val' loss {losses['val']:.4f}")
|
399 |
+
else:
|
400 |
+
print(f"\ngame {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): 'val' loss {losses['val']:.4f}")
|
401 |
if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
|
402 |
grad_clip_prev = grad_clip
|
403 |
grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
|
|
|
484 |
# get loss as float. note: this is a CPU-GPU sync point
|
485 |
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
|
486 |
lossf = loss.item() * gradient_accumulation_steps
|
487 |
+
if init_from == 'anneal':
|
488 |
+
print(f"game {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
489 |
+
else:
|
490 |
+
print(f"game {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
491 |
if wandb_log:
|
492 |
wandb.log({
|
493 |
"etc/iter": iter_num,
|