Spaces:
Running
Running
Commit
·
446e79f
1
Parent(s):
54a2a4e
working on fixes
Browse files- app.py +72 -37
- vms/training_service.py +10 -5
app.py
CHANGED
@@ -77,12 +77,68 @@ class VideoTrainerUI:
|
|
77 |
# UI will be in ready-to-start mode
|
78 |
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
def update_ui_state(self, **kwargs):
|
81 |
"""Update UI state with new values"""
|
82 |
current_state = self.trainer.load_ui_state()
|
83 |
current_state.update(kwargs)
|
84 |
self.trainer.save_ui_state(current_state)
|
85 |
-
return
|
|
|
86 |
|
87 |
def load_ui_values(self):
|
88 |
"""Load UI state values for initializing form fields"""
|
@@ -130,6 +186,19 @@ class VideoTrainerUI:
|
|
130 |
)
|
131 |
)
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
def show_refreshing_status(self) -> List[List[str]]:
|
134 |
"""Show a 'Refreshing...' status in the dataframe"""
|
135 |
return [["Refreshing...", "please wait"]]
|
@@ -1421,52 +1490,18 @@ class VideoTrainerUI:
|
|
1421 |
]
|
1422 |
)
|
1423 |
|
1424 |
-
# Add this new method to get initial button states:
|
1425 |
-
def get_initial_button_states(self):
|
1426 |
-
"""Get the initial states for training buttons based on recovery status"""
|
1427 |
-
recovery_result = self.trainer.recover_interrupted_training()
|
1428 |
-
ui_updates = recovery_result.get("ui_updates", {})
|
1429 |
-
|
1430 |
-
# Return button states in the correct order
|
1431 |
-
return (
|
1432 |
-
gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
|
1433 |
-
gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
|
1434 |
-
gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
|
1435 |
-
)
|
1436 |
-
|
1437 |
-
def initialize_ui_from_state(self):
|
1438 |
-
"""Initialize UI components from saved state"""
|
1439 |
-
ui_state = self.load_ui_values()
|
1440 |
-
|
1441 |
-
# Return values in order matching the outputs in app.load
|
1442 |
-
return (
|
1443 |
-
ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
|
1444 |
-
ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
|
1445 |
-
ui_state.get("lora_rank", "128"),
|
1446 |
-
ui_state.get("lora_alpha", "128"),
|
1447 |
-
ui_state.get("num_epochs", 70),
|
1448 |
-
ui_state.get("batch_size", 1),
|
1449 |
-
ui_state.get("learning_rate", 3e-5),
|
1450 |
-
ui_state.get("save_iterations", 500)
|
1451 |
-
)
|
1452 |
|
1453 |
-
# Auto-refresh timers
|
1454 |
app.load(
|
1455 |
-
fn=
|
1456 |
-
self.refresh_dataset(),
|
1457 |
-
*self.get_initial_button_states(),
|
1458 |
-
# Load saved UI state values
|
1459 |
-
*self.initialize_ui_from_state()
|
1460 |
-
),
|
1461 |
outputs=[
|
1462 |
video_list, training_dataset,
|
1463 |
start_btn, stop_btn, pause_resume_btn,
|
1464 |
-
# Add outputs for UI fields
|
1465 |
training_preset, model_type, lora_rank, lora_alpha,
|
1466 |
num_epochs, batch_size, learning_rate, save_iterations
|
1467 |
]
|
1468 |
)
|
1469 |
|
|
|
1470 |
timer = gr.Timer(value=1)
|
1471 |
timer.tick(
|
1472 |
fn=lambda: (
|
|
|
77 |
# UI will be in ready-to-start mode
|
78 |
|
79 |
|
80 |
+
def initialize_app_state(self):
|
81 |
+
"""Initialize all app state in one function to ensure correct output count"""
|
82 |
+
# Get dataset info
|
83 |
+
video_list, training_dataset = self.refresh_dataset()
|
84 |
+
|
85 |
+
# Get button states
|
86 |
+
button_states = self.get_initial_button_states()
|
87 |
+
start_btn = button_states[0]
|
88 |
+
stop_btn = button_states[1]
|
89 |
+
pause_resume_btn = button_states[2]
|
90 |
+
|
91 |
+
# Get UI form values
|
92 |
+
ui_state = self.load_ui_values()
|
93 |
+
training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
|
94 |
+
model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
|
95 |
+
lora_rank_val = ui_state.get("lora_rank", "128")
|
96 |
+
lora_alpha_val = ui_state.get("lora_alpha", "128")
|
97 |
+
num_epochs_val = int(ui_state.get("num_epochs", 70))
|
98 |
+
batch_size_val = int(ui_state.get("batch_size", 1))
|
99 |
+
learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
|
100 |
+
save_iterations_val = int(ui_state.get("save_iterations", 500))
|
101 |
+
|
102 |
+
# Return all values in the exact order expected by outputs
|
103 |
+
return (
|
104 |
+
video_list,
|
105 |
+
training_dataset,
|
106 |
+
start_btn,
|
107 |
+
stop_btn,
|
108 |
+
pause_resume_btn,
|
109 |
+
training_preset,
|
110 |
+
model_type_val,
|
111 |
+
lora_rank_val,
|
112 |
+
lora_alpha_val,
|
113 |
+
num_epochs_val,
|
114 |
+
batch_size_val,
|
115 |
+
learning_rate_val,
|
116 |
+
save_iterations_val
|
117 |
+
)
|
118 |
+
|
119 |
+
def initialize_ui_from_state(self):
|
120 |
+
"""Initialize UI components from saved state"""
|
121 |
+
ui_state = self.load_ui_values()
|
122 |
+
|
123 |
+
# Return values in order matching the outputs in app.load
|
124 |
+
return (
|
125 |
+
ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
|
126 |
+
ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
|
127 |
+
ui_state.get("lora_rank", "128"),
|
128 |
+
ui_state.get("lora_alpha", "128"),
|
129 |
+
ui_state.get("num_epochs", 70),
|
130 |
+
ui_state.get("batch_size", 1),
|
131 |
+
ui_state.get("learning_rate", 3e-5),
|
132 |
+
ui_state.get("save_iterations", 500)
|
133 |
+
)
|
134 |
+
|
135 |
def update_ui_state(self, **kwargs):
|
136 |
"""Update UI state with new values"""
|
137 |
current_state = self.trainer.load_ui_state()
|
138 |
current_state.update(kwargs)
|
139 |
self.trainer.save_ui_state(current_state)
|
140 |
+
# Don't return anything to avoid Gradio warnings
|
141 |
+
return None
|
142 |
|
143 |
def load_ui_values(self):
|
144 |
"""Load UI state values for initializing form fields"""
|
|
|
186 |
)
|
187 |
)
|
188 |
|
189 |
+
# Add this new method to get initial button states:
|
190 |
+
def get_initial_button_states(self):
|
191 |
+
"""Get the initial states for training buttons based on recovery status"""
|
192 |
+
recovery_result = self.trainer.recover_interrupted_training()
|
193 |
+
ui_updates = recovery_result.get("ui_updates", {})
|
194 |
+
|
195 |
+
# Return button states in the correct order
|
196 |
+
return (
|
197 |
+
gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
|
198 |
+
gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
|
199 |
+
gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
|
200 |
+
)
|
201 |
+
|
202 |
def show_refreshing_status(self) -> List[List[str]]:
|
203 |
"""Show a 'Refreshing...' status in the dataframe"""
|
204 |
return [["Refreshing...", "please wait"]]
|
|
|
1490 |
]
|
1491 |
)
|
1492 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1493 |
|
|
|
1494 |
app.load(
|
1495 |
+
fn=self.initialize_app_state,
|
|
|
|
|
|
|
|
|
|
|
1496 |
outputs=[
|
1497 |
video_list, training_dataset,
|
1498 |
start_btn, stop_btn, pause_resume_btn,
|
|
|
1499 |
training_preset, model_type, lora_rank, lora_alpha,
|
1500 |
num_epochs, batch_size, learning_rate, save_iterations
|
1501 |
]
|
1502 |
)
|
1503 |
|
1504 |
+
# Auto-refresh timers
|
1505 |
timer = gr.Timer(value=1)
|
1506 |
timer.tick(
|
1507 |
fn=lambda: (
|
vms/training_service.py
CHANGED
@@ -164,12 +164,11 @@ class TrainingService:
|
|
164 |
|
165 |
if not self.status_file.exists():
|
166 |
return default_status
|
167 |
-
|
168 |
try:
|
169 |
with open(self.status_file, 'r') as f:
|
170 |
status = json.load(f)
|
171 |
-
|
172 |
-
|
173 |
# Check if process is actually running
|
174 |
if self.pid_file.exists():
|
175 |
with open(self.pid_file, 'r') as f:
|
@@ -177,14 +176,20 @@ class TrainingService:
|
|
177 |
if not psutil.pid_exists(pid):
|
178 |
# Process died unexpectedly
|
179 |
if status['status'] == 'training':
|
|
|
|
|
|
|
|
|
180 |
status['status'] = 'error'
|
181 |
status['message'] = 'Training process terminated unexpectedly'
|
182 |
-
|
|
|
|
|
183 |
else:
|
184 |
status['status'] = 'stopped'
|
185 |
status['message'] = 'Training process not found'
|
186 |
return status
|
187 |
-
|
188 |
except (json.JSONDecodeError, ValueError):
|
189 |
return default_status
|
190 |
|
|
|
164 |
|
165 |
if not self.status_file.exists():
|
166 |
return default_status
|
167 |
+
|
168 |
try:
|
169 |
with open(self.status_file, 'r') as f:
|
170 |
status = json.load(f)
|
171 |
+
|
|
|
172 |
# Check if process is actually running
|
173 |
if self.pid_file.exists():
|
174 |
with open(self.pid_file, 'r') as f:
|
|
|
176 |
if not psutil.pid_exists(pid):
|
177 |
# Process died unexpectedly
|
178 |
if status['status'] == 'training':
|
179 |
+
# Only log this once by checking if we've already updated the status
|
180 |
+
if not hasattr(self, '_process_terminated_logged') or not self._process_terminated_logged:
|
181 |
+
self.append_log("Training process terminated unexpectedly")
|
182 |
+
self._process_terminated_logged = True
|
183 |
status['status'] = 'error'
|
184 |
status['message'] = 'Training process terminated unexpectedly'
|
185 |
+
# Update the status file to avoid repeated logging
|
186 |
+
with open(self.status_file, 'w') as f:
|
187 |
+
json.dump(status, f, indent=2)
|
188 |
else:
|
189 |
status['status'] = 'stopped'
|
190 |
status['message'] = 'Training process not found'
|
191 |
return status
|
192 |
+
|
193 |
except (json.JSONDecodeError, ValueError):
|
194 |
return default_status
|
195 |
|