jbilcke-hf HF Staff committed on
Commit
446e79f
·
1 Parent(s): 54a2a4e

working on fixes

Browse files
Files changed (2) hide show
  1. app.py +72 -37
  2. vms/training_service.py +10 -5
app.py CHANGED
@@ -77,12 +77,68 @@ class VideoTrainerUI:
77
  # UI will be in ready-to-start mode
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def update_ui_state(self, **kwargs):
81
  """Update UI state with new values"""
82
  current_state = self.trainer.load_ui_state()
83
  current_state.update(kwargs)
84
  self.trainer.save_ui_state(current_state)
85
- return current_state
 
86
 
87
  def load_ui_values(self):
88
  """Load UI state values for initializing form fields"""
@@ -130,6 +186,19 @@ class VideoTrainerUI:
130
  )
131
  )
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def show_refreshing_status(self) -> List[List[str]]:
134
  """Show a 'Refreshing...' status in the dataframe"""
135
  return [["Refreshing...", "please wait"]]
@@ -1421,52 +1490,18 @@ class VideoTrainerUI:
1421
  ]
1422
  )
1423
 
1424
- # Add this new method to get initial button states:
1425
- def get_initial_button_states(self):
1426
- """Get the initial states for training buttons based on recovery status"""
1427
- recovery_result = self.trainer.recover_interrupted_training()
1428
- ui_updates = recovery_result.get("ui_updates", {})
1429
-
1430
- # Return button states in the correct order
1431
- return (
1432
- gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
1433
- gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
1434
- gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
1435
- )
1436
-
1437
- def initialize_ui_from_state(self):
1438
- """Initialize UI components from saved state"""
1439
- ui_state = self.load_ui_values()
1440
-
1441
- # Return values in order matching the outputs in app.load
1442
- return (
1443
- ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
1444
- ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
1445
- ui_state.get("lora_rank", "128"),
1446
- ui_state.get("lora_alpha", "128"),
1447
- ui_state.get("num_epochs", 70),
1448
- ui_state.get("batch_size", 1),
1449
- ui_state.get("learning_rate", 3e-5),
1450
- ui_state.get("save_iterations", 500)
1451
- )
1452
 
1453
- # Auto-refresh timers
1454
  app.load(
1455
- fn=lambda: (
1456
- self.refresh_dataset(),
1457
- *self.get_initial_button_states(),
1458
- # Load saved UI state values
1459
- *self.initialize_ui_from_state()
1460
- ),
1461
  outputs=[
1462
  video_list, training_dataset,
1463
  start_btn, stop_btn, pause_resume_btn,
1464
- # Add outputs for UI fields
1465
  training_preset, model_type, lora_rank, lora_alpha,
1466
  num_epochs, batch_size, learning_rate, save_iterations
1467
  ]
1468
  )
1469
 
 
1470
  timer = gr.Timer(value=1)
1471
  timer.tick(
1472
  fn=lambda: (
 
77
  # UI will be in ready-to-start mode
78
 
79
 
80
+ def initialize_app_state(self):
81
+ """Initialize all app state in one function to ensure correct output count"""
82
+ # Get dataset info
83
+ video_list, training_dataset = self.refresh_dataset()
84
+
85
+ # Get button states
86
+ button_states = self.get_initial_button_states()
87
+ start_btn = button_states[0]
88
+ stop_btn = button_states[1]
89
+ pause_resume_btn = button_states[2]
90
+
91
+ # Get UI form values
92
+ ui_state = self.load_ui_values()
93
+ training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
94
+ model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
95
+ lora_rank_val = ui_state.get("lora_rank", "128")
96
+ lora_alpha_val = ui_state.get("lora_alpha", "128")
97
+ num_epochs_val = int(ui_state.get("num_epochs", 70))
98
+ batch_size_val = int(ui_state.get("batch_size", 1))
99
+ learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
100
+ save_iterations_val = int(ui_state.get("save_iterations", 500))
101
+
102
+ # Return all values in the exact order expected by outputs
103
+ return (
104
+ video_list,
105
+ training_dataset,
106
+ start_btn,
107
+ stop_btn,
108
+ pause_resume_btn,
109
+ training_preset,
110
+ model_type_val,
111
+ lora_rank_val,
112
+ lora_alpha_val,
113
+ num_epochs_val,
114
+ batch_size_val,
115
+ learning_rate_val,
116
+ save_iterations_val
117
+ )
118
+
119
+ def initialize_ui_from_state(self):
120
+ """Initialize UI components from saved state"""
121
+ ui_state = self.load_ui_values()
122
+
123
+ # Return values in order matching the outputs in app.load
124
+ return (
125
+ ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
126
+ ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
127
+ ui_state.get("lora_rank", "128"),
128
+ ui_state.get("lora_alpha", "128"),
129
+ ui_state.get("num_epochs", 70),
130
+ ui_state.get("batch_size", 1),
131
+ ui_state.get("learning_rate", 3e-5),
132
+ ui_state.get("save_iterations", 500)
133
+ )
134
+
135
  def update_ui_state(self, **kwargs):
136
  """Update UI state with new values"""
137
  current_state = self.trainer.load_ui_state()
138
  current_state.update(kwargs)
139
  self.trainer.save_ui_state(current_state)
140
+ # Don't return anything to avoid Gradio warnings
141
+ return None
142
 
143
  def load_ui_values(self):
144
  """Load UI state values for initializing form fields"""
 
186
  )
187
  )
188
 
189
+ # Add this new method to get initial button states:
190
+ def get_initial_button_states(self):
191
+ """Get the initial states for training buttons based on recovery status"""
192
+ recovery_result = self.trainer.recover_interrupted_training()
193
+ ui_updates = recovery_result.get("ui_updates", {})
194
+
195
+ # Return button states in the correct order
196
+ return (
197
+ gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
198
+ gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
199
+ gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
200
+ )
201
+
202
  def show_refreshing_status(self) -> List[List[str]]:
203
  """Show a 'Refreshing...' status in the dataframe"""
204
  return [["Refreshing...", "please wait"]]
 
1490
  ]
1491
  )
1492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1493
 
 
1494
  app.load(
1495
+ fn=self.initialize_app_state,
 
 
 
 
 
1496
  outputs=[
1497
  video_list, training_dataset,
1498
  start_btn, stop_btn, pause_resume_btn,
 
1499
  training_preset, model_type, lora_rank, lora_alpha,
1500
  num_epochs, batch_size, learning_rate, save_iterations
1501
  ]
1502
  )
1503
 
1504
+ # Auto-refresh timers
1505
  timer = gr.Timer(value=1)
1506
  timer.tick(
1507
  fn=lambda: (
vms/training_service.py CHANGED
@@ -164,12 +164,11 @@ class TrainingService:
164
 
165
  if not self.status_file.exists():
166
  return default_status
167
-
168
  try:
169
  with open(self.status_file, 'r') as f:
170
  status = json.load(f)
171
- #print("status found in the json:", status)
172
-
173
  # Check if process is actually running
174
  if self.pid_file.exists():
175
  with open(self.pid_file, 'r') as f:
@@ -177,14 +176,20 @@ class TrainingService:
177
  if not psutil.pid_exists(pid):
178
  # Process died unexpectedly
179
  if status['status'] == 'training':
 
 
 
 
180
  status['status'] = 'error'
181
  status['message'] = 'Training process terminated unexpectedly'
182
- self.append_log("Training process terminated unexpectedly")
 
 
183
  else:
184
  status['status'] = 'stopped'
185
  status['message'] = 'Training process not found'
186
  return status
187
-
188
  except (json.JSONDecodeError, ValueError):
189
  return default_status
190
 
 
164
 
165
  if not self.status_file.exists():
166
  return default_status
167
+
168
  try:
169
  with open(self.status_file, 'r') as f:
170
  status = json.load(f)
171
+
 
172
  # Check if process is actually running
173
  if self.pid_file.exists():
174
  with open(self.pid_file, 'r') as f:
 
176
  if not psutil.pid_exists(pid):
177
  # Process died unexpectedly
178
  if status['status'] == 'training':
179
+ # Only log this once by checking if we've already updated the status
180
+ if not hasattr(self, '_process_terminated_logged') or not self._process_terminated_logged:
181
+ self.append_log("Training process terminated unexpectedly")
182
+ self._process_terminated_logged = True
183
  status['status'] = 'error'
184
  status['message'] = 'Training process terminated unexpectedly'
185
+ # Update the status file to avoid repeated logging
186
+ with open(self.status_file, 'w') as f:
187
+ json.dump(status, f, indent=2)
188
  else:
189
  status['status'] = 'stopped'
190
  status['message'] = 'Training process not found'
191
  return status
192
+
193
  except (json.JSONDecodeError, ValueError):
194
  return default_status
195