jbilcke-hf HF Staff commited on
Commit
ac45732
·
1 Parent(s): 74afeb1

fix the checkpoint recovery

Browse files
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -325,21 +325,8 @@ class SFTTrainer:
325
  resume_from_checkpoint = self.args.resume_from_checkpoint
326
  if resume_from_checkpoint == "latest":
327
  resume_from_checkpoint = -1
328
-
329
- # Store the load result
330
- load_successful = False
331
  if resume_from_checkpoint is not None:
332
- load_successful = self.checkpointer.load(resume_from_checkpoint)
333
-
334
- # If loading succeeded and we have a specific checkpoint path
335
- if load_successful and isinstance(resume_from_checkpoint, str) and resume_from_checkpoint != "latest":
336
- try:
337
- step = int(resume_from_checkpoint.split("_")[-1])
338
- self.state.train_state.step = step
339
- logger.info(f"Explicitly setting training step to {step} based on checkpoint path")
340
- except (ValueError, IndexError):
341
- logger.warning(f"Could not parse step number from checkpoint path: {resume_from_checkpoint}")
342
-
343
 
344
  def _train(self) -> None:
345
  logger.info("Starting training")
 
325
  resume_from_checkpoint = self.args.resume_from_checkpoint
326
  if resume_from_checkpoint == "latest":
327
  resume_from_checkpoint = -1
 
 
 
328
  if resume_from_checkpoint is not None:
329
+ self.checkpointer.load(resume_from_checkpoint)
 
 
 
 
 
 
 
 
 
 
330
 
331
  def _train(self) -> None:
332
  logger.info("Starting training")
vms/ui/project/services/training.py CHANGED
@@ -1097,6 +1097,11 @@ class TrainingService:
1097
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1098
  checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
1099
  logger.info(f"Found checkpoint at step {checkpoint_step}")
 
 
 
 
 
1100
  else:
1101
  logger.warning("No checkpoints found for recovery")
1102
  # Set buttons for no active training
 
1097
  latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
1098
  checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
1099
  logger.info(f"Found checkpoint at step {checkpoint_step}")
1100
+
1101
+ # both options are valid, but imho it is easier to just return "latest"
1102
+ # under the hood Finetrainers will convert ("latest") to (-1)
1103
+ #latest_checkpoint = int(checkpoint_step)
1104
+ latest_checkpoint = "latest"
1105
  else:
1106
  logger.warning("No checkpoints found for recovery")
1107
  # Set buttons for no active training