Spaces:
Running
Running
Commit
·
ac45732
1
Parent(s):
74afeb1
fix the checkpoint recovery
Browse files
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -325,21 +325,8 @@ class SFTTrainer:
|
|
325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
326 |
if resume_from_checkpoint == "latest":
|
327 |
resume_from_checkpoint = -1
|
328 |
-
|
329 |
-
# Store the load result
|
330 |
-
load_successful = False
|
331 |
if resume_from_checkpoint is not None:
|
332 |
-
|
333 |
-
|
334 |
-
# If loading succeeded and we have a specific checkpoint path
|
335 |
-
if load_successful and isinstance(resume_from_checkpoint, str) and resume_from_checkpoint != "latest":
|
336 |
-
try:
|
337 |
-
step = int(resume_from_checkpoint.split("_")[-1])
|
338 |
-
self.state.train_state.step = step
|
339 |
-
logger.info(f"Explicitly setting training step to {step} based on checkpoint path")
|
340 |
-
except (ValueError, IndexError):
|
341 |
-
logger.warning(f"Could not parse step number from checkpoint path: {resume_from_checkpoint}")
|
342 |
-
|
343 |
|
344 |
def _train(self) -> None:
|
345 |
logger.info("Starting training")
|
|
|
325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
326 |
if resume_from_checkpoint == "latest":
|
327 |
resume_from_checkpoint = -1
|
|
|
|
|
|
|
328 |
if resume_from_checkpoint is not None:
|
329 |
+
self.checkpointer.load(resume_from_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
def _train(self) -> None:
|
332 |
logger.info("Starting training")
|
vms/ui/project/services/training.py
CHANGED
@@ -1097,6 +1097,11 @@ class TrainingService:
|
|
1097 |
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
1098 |
checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
|
1099 |
logger.info(f"Found checkpoint at step {checkpoint_step}")
|
|
|
|
|
|
|
|
|
|
|
1100 |
else:
|
1101 |
logger.warning("No checkpoints found for recovery")
|
1102 |
# Set buttons for no active training
|
|
|
1097 |
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
1098 |
checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
|
1099 |
logger.info(f"Found checkpoint at step {checkpoint_step}")
|
1100 |
+
|
1101 |
+
# both options are valid, but imho it is easier to just return "latest"
|
1102 |
+
# under the hood Finetrainers will convert ("latest") to (-1)
|
1103 |
+
#latest_checkpoint = int(checkpoint_step)
|
1104 |
+
latest_checkpoint = "latest"
|
1105 |
else:
|
1106 |
logger.warning("No checkpoints found for recovery")
|
1107 |
# Set buttons for no active training
|