MilesCranmer commited on
Commit
1e43bdc
1 Parent(s): 3f60d16

Fix PyTorch lightning total steps

Browse files
Files changed (1) hide show
  1. examples/pysr_demo.ipynb +1 -3
examples/pysr_demo.ipynb CHANGED
@@ -1118,14 +1118,12 @@
1118
  " return self.training_step(batch, batch_idx)\n",
1119
  "\n",
1120
  " def configure_optimizers(self):\n",
1121
- " self.trainer.reset_train_dataloader()\n",
1122
- "\n",
1123
  " optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)\n",
1124
  " scheduler = {\n",
1125
  " \"scheduler\": torch.optim.lr_scheduler.OneCycleLR(\n",
1126
  " optimizer,\n",
1127
  " max_lr=self.max_lr,\n",
1128
- " total_steps=self.total_steps,\n",
1129
  " final_div_factor=1e4,\n",
1130
  " ),\n",
1131
  " \"interval\": \"step\",\n",
 
1118
  " return self.training_step(batch, batch_idx)\n",
1119
  "\n",
1120
  " def configure_optimizers(self):\n",
 
 
1121
  " optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)\n",
1122
  " scheduler = {\n",
1123
  " \"scheduler\": torch.optim.lr_scheduler.OneCycleLR(\n",
1124
  " optimizer,\n",
1125
  " max_lr=self.max_lr,\n",
1126
+ " total_steps=self.trainer.estimated_stepping_batches,\n",
1127
  " final_div_factor=1e4,\n",
1128
  " ),\n",
1129
  " \"interval\": \"step\",\n",