Commit · 54a2a4e
1 Parent(s): 9545589

working on training job failure recovery
Files changed:
- app.py (+124 -4)
- vms/training_log_parser.py (+33 -34)
- vms/training_service.py (+188 -4)
app.py CHANGED

@@ -59,7 +59,43 @@ class VideoTrainerUI:
         self.captioner = CaptioningService()
         self._should_stop_captioning = False
         self.log_parser = TrainingLogParser()
-
+
+        # Try to recover any interrupted training sessions
+        recovery_result = self.trainer.recover_interrupted_training()
+
+        self.recovery_status = recovery_result.get("status", "unknown")
+        self.ui_updates = recovery_result.get("ui_updates", {})
+
+        if recovery_result["status"] == "recovered":
+            logger.info(f"Training recovery: {recovery_result['message']}")
+            # No need to do anything else - the training is already running
+        elif recovery_result["status"] == "running":
+            logger.info("Training process is already running")
+            # No need to do anything - the process is still alive
+        elif recovery_result["status"] in ["error", "idle"]:
+            logger.warning(f"Training status: {recovery_result['message']}")
+            # UI will be in ready-to-start mode
+
+    def update_ui_state(self, **kwargs):
+        """Update UI state with new values"""
+        current_state = self.trainer.load_ui_state()
+        current_state.update(kwargs)
+        self.trainer.save_ui_state(current_state)
+        return current_state
+
+    def load_ui_values(self):
+        """Load UI state values for initializing form fields"""
+        ui_state = self.trainer.load_ui_state()
+
+        # Convert types as needed since JSON stores everything as strings
+        ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
+        ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
+        ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
+        ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
+
+        return ui_state
+
     def update_captioning_buttons_start(self):
         """Return individual button values instead of a dictionary"""
         return (

@@ -1120,12 +1156,55 @@
             return gr.update(value=repo_id, error=None)

         # Connect events
+
+        # Save state when model type changes
         model_type.change(
+            fn=lambda v: self.update_ui_state(model_type=v),
+            inputs=[model_type],
+            outputs=[] # No UI update needed
+        ).then(
             fn=update_model_info,
             inputs=[model_type],
             outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
         )

+        # the following change listeners are used for UI persistence
+        lora_rank.change(
+            fn=lambda v: self.update_ui_state(lora_rank=v),
+            inputs=[lora_rank],
+            outputs=[]
+        )
+
+        lora_alpha.change(
+            fn=lambda v: self.update_ui_state(lora_alpha=v),
+            inputs=[lora_alpha],
+            outputs=[]
+        )
+
+        num_epochs.change(
+            fn=lambda v: self.update_ui_state(num_epochs=v),
+            inputs=[num_epochs],
+            outputs=[]
+        )
+
+        batch_size.change(
+            fn=lambda v: self.update_ui_state(batch_size=v),
+            inputs=[batch_size],
+            outputs=[]
+        )
+
+        learning_rate.change(
+            fn=lambda v: self.update_ui_state(learning_rate=v),
+            inputs=[learning_rate],
+            outputs=[]
+        )
+
+        save_iterations.change(
+            fn=lambda v: self.update_ui_state(save_iterations=v),
+            inputs=[save_iterations],
+            outputs=[]
+        )
+
         async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
             videos = self.list_unprocessed_videos()
             # If scene detection isn't already running and there are videos to process,

@@ -1243,8 +1322,13 @@
             fn=self.list_training_files_to_caption,
             outputs=[training_dataset]
         )
-
+
+        # Save state when training preset changes
         training_preset.change(
+            fn=lambda v: self.update_ui_state(training_preset=v),
+            inputs=[training_preset],
+            outputs=[] # No UI update needed
+        ).then(
             fn=self.update_training_params,
             inputs=[training_preset],
             outputs=[

@@ -1337,13 +1421,49 @@
             ]
         )

+    # Add this new method to get initial button states:
+    def get_initial_button_states(self):
+        """Get the initial states for training buttons based on recovery status"""
+        recovery_result = self.trainer.recover_interrupted_training()
+        ui_updates = recovery_result.get("ui_updates", {})
+
+        # Return button states in the correct order
+        return (
+            gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
+            gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
+            gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
+        )
+
+    def initialize_ui_from_state(self):
+        """Initialize UI components from saved state"""
+        ui_state = self.load_ui_values()
+
+        # Return values in order matching the outputs in app.load
+        return (
+            ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
+            ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
+            ui_state.get("lora_rank", "128"),
+            ui_state.get("lora_alpha", "128"),
+            ui_state.get("num_epochs", 70),
+            ui_state.get("batch_size", 1),
+            ui_state.get("learning_rate", 3e-5),
+            ui_state.get("save_iterations", 500)
+        )
+
         # Auto-refresh timers
         app.load(
             fn=lambda: (
-                self.refresh_dataset()
+                self.refresh_dataset(),
+                *self.get_initial_button_states(),
+                # Load saved UI state values
+                *self.initialize_ui_from_state()
             ),
             outputs=[
-                video_list, training_dataset
+                video_list, training_dataset,
+                start_btn, stop_btn, pause_resume_btn,
+                # Add outputs for UI fields
+                training_preset, model_type, lora_rank, lora_alpha,
+                num_epochs, batch_size, learning_rate, save_iterations
             ]
         )

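Note on the persistence wiring above: every `.change()` listener funnels into `update_ui_state(...)`, which merges the new value into the trainer's saved state and writes it back to disk. A minimal standalone sketch of that round-trip, assuming a `ui_state.json` path as in this commit (the concrete file location and the `model_type` / `training_preset` defaults, which really come from `OUTPUT_PATH`, `MODEL_TYPES`, and `TRAINING_PRESETS`, are placeholders here):

import json
from pathlib import Path

# Placeholder path for illustration; the app resolves it as OUTPUT_PATH / "ui_state.json".
UI_STATE_FILE = Path("output/ui_state.json")

# Defaults mirror load_ui_state(); model_type / training_preset omitted because
# their real defaults come from MODEL_TYPES / TRAINING_PRESETS.
DEFAULTS = {
    "lora_rank": "128",
    "lora_alpha": "128",
    "num_epochs": 70,
    "batch_size": 1,
    "learning_rate": 3e-5,
    "save_iterations": 500,
}

def load_ui_state() -> dict:
    # Merge saved values over defaults so missing keys fall back cleanly.
    state = dict(DEFAULTS)
    if UI_STATE_FILE.exists():
        state.update(json.loads(UI_STATE_FILE.read_text()))
    return state

def update_ui_state(**kwargs) -> dict:
    # What each `lambda v: self.update_ui_state(field=v)` listener boils down to.
    state = load_ui_state()
    state.update(kwargs)
    UI_STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    UI_STATE_FILE.write_text(json.dumps(state, indent=2))
    return state

print(update_ui_state(batch_size=4)["batch_size"])  # -> 4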
vms/training_log_parser.py CHANGED

@@ -34,7 +34,14 @@ class TrainingState:

     def to_dict(self) -> Dict[str, Any]:
         """Convert state to dictionary for UI updates"""
-
+        # Calculate elapsed time only if training is active and we have a start time
+        if self.start_time and self.status in ["training", "initializing"]:
+            elapsed = str(datetime.now() - self.start_time)
+        else:
+            # Use the last known elapsed time or show 0
+            elapsed = "0:00:00" if not self.last_step_time else str(self.last_step_time - self.start_time if self.start_time else "0:00:00")
+
+        # Use precomputed remaining time from logs if available
         remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."

         return {

@@ -74,10 +81,11 @@ class TrainingLogParser:
         if ("Started training" in line) or ("Starting training" in line):
             self.state.status = "training"

+        # Check for "Training steps:" which contains the progress information
         if "Training steps:" in line:
             # Set status to training if we see this
             self.state.status = "training"
-
+
             if not self.state.start_time:
                 self.state.start_time = datetime.now()

@@ -97,36 +105,23 @@ class TrainingLogParser:
                 if match:
                     setattr(self.state, attr, float(match.group(1)))

-            # Create formatted timedelta
-            if days > 0:
-                formatted_time = f"{days}d {hours}h {minutes}m {seconds}s"
-            elif hours > 0:
-                formatted_time = f"{hours}h {minutes}m {seconds}s"
-            elif minutes > 0:
-                formatted_time = f"{minutes}m {seconds}s"
-            else:
-                formatted_time = f"{seconds}s"
-
-            self.state.estimated_remaining = formatted_time
-            self.state.last_step_time = now
-
+            # Extract time remaining directly from the log
+            # Format: [MM:SS<M:SS:SS, SS.SSs/it]
+            time_remaining_match = re.search(r"<(\d+:\d+:\d+)", line)
+            if time_remaining_match:
+                remaining_str = time_remaining_match.group(1)
+                # Store the string directly - no need to parse it
+                self.state.estimated_remaining = remaining_str
+
+            # If no direct time estimate, look for hour:min format
+            if not time_remaining_match:
+                hour_min_match = re.search(r"<(\d+h\s*\d+m)", line)
+                if hour_min_match:
+                    self.state.estimated_remaining = hour_min_match.group(1)
+
+            # Update last processing time
+            self.state.last_step_time = datetime.now()
+
             logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}")
             return self.state.to_dict()

@@ -162,12 +157,16 @@ class TrainingLogParser:

         # Completion states
         if "Training completed successfully" in line:
-            self.state.status = "completed"
+            self.status = "completed"
+            # Store final elapsed time
+            self.last_step_time = datetime.now()
             logger.info("Training completed")
             return self.state.to_dict()

         if any(x in line for x in ["Training process stopped", "Training stopped"]):
-            self.state.status = "stopped"
+            self.status = "stopped"
+            # Store final elapsed time
+            self.last_step_time = datetime.now()
             logger.info("Training stopped")
             return self.state.to_dict()

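The parser change above stops reconstructing a timedelta by hand and instead lifts the remaining-time estimate straight out of the progress-bar text. A quick check of the two regexes against a tqdm-style line (the sample line itself is invented for illustration):

import re

# Hypothetical "Training steps:" progress line in the tqdm format the parser expects.
line = "Training steps:  28%|██▊       | 280/1000 [26:32<1:08:59, 2.13s/it]"

match = re.search(r"<(\d+:\d+:\d+)", line)  # H:MM:SS estimate after the "<"
if match:
    print(match.group(1))  # -> 1:08:59
else:
    # Fallback used when the estimate is rendered as "Xh Ym"
    hour_min = re.search(r"<(\d+h\s*\d+m)", line)
    print(hour_min.group(1) if hour_min else "calculating...")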
vms/training_service.py CHANGED

@@ -38,7 +38,7 @@ class TrainingService:
         self.setup_logging()

         logger.info("Training service initialized")
-
+
     def setup_logging(self):
         """Set up logging with proper handler management"""
         global logger

@@ -96,16 +96,58 @@ class TrainingService:
         if self.file_handler:
             self.file_handler.close()

+
+    def save_ui_state(self, values: Dict[str, Any]) -> None:
+        """Save current UI state to file"""
+        ui_state_file = OUTPUT_PATH / "ui_state.json"
+        try:
+            with open(ui_state_file, 'w') as f:
+                json.dump(values, f, indent=2)
+            logger.debug(f"UI state saved: {values}")
+        except Exception as e:
+            logger.error(f"Error saving UI state: {str(e)}")
+
+    def load_ui_state(self) -> Dict[str, Any]:
+        """Load saved UI state"""
+        ui_state_file = OUTPUT_PATH / "ui_state.json"
+        default_state = {
+            "model_type": list(MODEL_TYPES.keys())[0],
+            "lora_rank": "128",
+            "lora_alpha": "128",
+            "num_epochs": 70,
+            "batch_size": 1,
+            "learning_rate": 3e-5,
+            "save_iterations": 500,
+            "training_preset": list(TRAINING_PRESETS.keys())[0]
+        }
+
+        if not ui_state_file.exists():
+            return default_state
+
+        try:
+            with open(ui_state_file, 'r') as f:
+                saved_state = json.load(f)
+            # Make sure we have all keys (in case structure changed)
+            merged_state = default_state.copy()
+            merged_state.update(saved_state)
+            return merged_state
+        except Exception as e:
+            logger.error(f"Error loading UI state: {str(e)}")
+            return default_state
+
+    # Modify save_session to also store the UI state at training start
     def save_session(self, params: Dict) -> None:
         """Save training session parameters"""
         session_data = {
             "timestamp": datetime.now().isoformat(),
             "params": params,
-            "status": self.get_status()
+            "status": self.get_status(),
+            # Add UI state at the time training started
+            "initial_ui_state": self.load_ui_state()
         }
         with open(self.session_file, 'w') as f:
             json.dump(session_data, f, indent=2)
-
+
     def load_session(self) -> Optional[Dict]:
         """Load saved training session"""
         if self.session_file.exists():

@@ -225,6 +267,7 @@ class TrainingService:
         save_iterations: int,
         repo_id: str,
         preset_name: str,
+        resume_from_checkpoint: Optional[str] = None,
     ) -> Tuple[str, str]:
         """Start training with finetrainers"""

@@ -295,6 +338,11 @@ class TrainingService:
         config.lr = float(learning_rate)
         config.checkpointing_steps = int(save_iterations)

+        # Update with resume_from_checkpoint if provided
+        if resume_from_checkpoint:
+            config.resume_from_checkpoint = resume_from_checkpoint
+            self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
+
         # Common settings for both models
         config.mixed_precision = "bf16"
         config.seed = 42

@@ -477,10 +525,146 @@ class TrainingService:
         try:
             with open(self.pid_file, 'r') as f:
                 pid = int(f.read().strip())
-
+
+            # Check if process exists AND is a Python process running train.py
+            if psutil.pid_exists(pid):
+                try:
+                    process = psutil.Process(pid)
+                    cmdline = process.cmdline()
+                    # Check if it's a Python process running train.py
+                    return any('train.py' in cmd for cmd in cmdline)
+                except (psutil.NoSuchProcess, psutil.AccessDenied):
+                    return False
+            return False
         except:
             return False

+    def recover_interrupted_training(self) -> Dict[str, Any]:
+        """Attempt to recover interrupted training
+
+        Returns:
+            Dict with recovery status and UI updates
+        """
+        status = self.get_status()
+        ui_updates = {}
+
+        # If status indicates training but process isn't running, try to recover
+        if status.get('status') == 'training' and not self.is_training_running():
+            logger.info("Detected interrupted training session, attempting to recover...")
+
+            # Get the latest checkpoint
+            last_session = self.load_session()
+            if not last_session:
+                logger.warning("No session data found for recovery")
+                # Set buttons for no active training
+                ui_updates = {
+                    "start_btn": {"interactive": True, "variant": "primary"},
+                    "stop_btn": {"interactive": False, "variant": "secondary"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+                }
+                return {"status": "error", "message": "No session data found", "ui_updates": ui_updates}
+
+            # Find the latest checkpoint
+            checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+            if not checkpoints:
+                logger.warning("No checkpoints found for recovery")
+                # Set buttons for no active training
+                ui_updates = {
+                    "start_btn": {"interactive": True, "variant": "primary"},
+                    "stop_btn": {"interactive": False, "variant": "secondary"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+                }
+                return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
+
+            latest_checkpoint = max(checkpoints, key=os.path.getmtime)
+            checkpoint_step = int(latest_checkpoint.name.split("-")[1])
+
+            logger.info(f"Found checkpoint at step {checkpoint_step}, attempting to resume")
+
+            # Extract parameters from the saved session (not current UI state)
+            # This ensures we use the original training parameters
+            params = last_session.get('params', {})
+            initial_ui_state = last_session.get('initial_ui_state', {})
+
+            # Add UI updates to restore the training parameters in the UI
+            # This shows the user what values are being used for the resumed training
+            ui_updates.update({
+                "model_type": gr.update(value=params.get('model_type', list(MODEL_TYPES.keys())[0])),
+                "lora_rank": gr.update(value=params.get('lora_rank', "128")),
+                "lora_alpha": gr.update(value=params.get('lora_alpha', "128")),
+                "num_epochs": gr.update(value=params.get('num_epochs', 70)),
+                "batch_size": gr.update(value=params.get('batch_size', 1)),
+                "learning_rate": gr.update(value=params.get('learning_rate', 3e-5)),
+                "save_iterations": gr.update(value=params.get('save_iterations', 500)),
+                "training_preset": gr.update(value=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]))
+            })
+
+            # Attempt to resume training using the ORIGINAL parameters
+            try:
+                # Extract required parameters from the session
+                model_type = params.get('model_type')
+                lora_rank = params.get('lora_rank')
+                lora_alpha = params.get('lora_alpha')
+                num_epochs = params.get('num_epochs')
+                batch_size = params.get('batch_size')
+                learning_rate = params.get('learning_rate')
+                save_iterations = params.get('save_iterations')
+                repo_id = params.get('repo_id')
+                preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
+
+                # Attempt to resume training
+                result = self.start_training(
+                    model_type=model_type,
+                    lora_rank=lora_rank,
+                    lora_alpha=lora_alpha,
+                    num_epochs=num_epochs,
+                    batch_size=batch_size,
+                    learning_rate=learning_rate,
+                    save_iterations=save_iterations,
+                    repo_id=repo_id,
+                    preset_name=preset_name,
+                    resume_from_checkpoint=str(latest_checkpoint)
+                )
+
+                # Set buttons for active training
+                ui_updates.update({
+                    "start_btn": {"interactive": False, "variant": "secondary"},
+                    "stop_btn": {"interactive": True, "variant": "stop"},
+                    "pause_resume_btn": {"interactive": True, "variant": "secondary"}
+                })
+
+                return {
+                    "status": "recovered",
+                    "message": f"Training resumed from checkpoint {checkpoint_step}",
+                    "result": result,
+                    "ui_updates": ui_updates
+                }
+            except Exception as e:
+                logger.error(f"Failed to resume training: {str(e)}")
+                # Set buttons for no active training
+                ui_updates.update({
+                    "start_btn": {"interactive": True, "variant": "primary"},
+                    "stop_btn": {"interactive": False, "variant": "secondary"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+                })
+                return {"status": "error", "message": f"Failed to resume: {str(e)}", "ui_updates": ui_updates}
+        elif self.is_training_running():
+            # Process is still running, set buttons accordingly
+            ui_updates = {
+                "start_btn": {"interactive": False, "variant": "secondary"},
+                "stop_btn": {"interactive": True, "variant": "stop"},
+                "pause_resume_btn": {"interactive": True, "variant": "secondary"}
+            }
+            return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
+        else:
+            # No training process, set buttons to default state
+            ui_updates = {
+                "start_btn": {"interactive": True, "variant": "primary"},
+                "stop_btn": {"interactive": False, "variant": "secondary"},
+                "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+            }
+            return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
+
     def clear_training_data(self) -> str:
         """Clear all training data"""
         if self.is_training_running():
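For reference, the recovery path added above boils down to picking the newest "checkpoint-<step>" directory under the output path and feeding it back into start_training(..., resume_from_checkpoint=...). A self-contained sketch of that selection step, assuming finetrainers' "checkpoint-<step>" directory naming as the diff does (the "output" path and the trainer variable in the usage comment are placeholders):

from pathlib import Path
from typing import Optional, Tuple

def find_latest_checkpoint(output_dir: Path) -> Tuple[Optional[Path], Optional[int]]:
    """Return the newest checkpoint directory and its step number, or (None, None)."""
    checkpoints = list(output_dir.glob("checkpoint-*"))
    if not checkpoints:
        return None, None
    # Newest by modification time, as in recover_interrupted_training()
    latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
    step = int(latest.name.split("-")[1])
    return latest, step

# Usage sketch: resume only when the saved session says "training" but no process is alive.
# latest, step = find_latest_checkpoint(Path("output"))
# if latest is not None:
#     trainer.start_training(..., resume_from_checkpoint=str(latest))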