Commit 61a25f0 ("fix")
Parent(s): d2662cc

Changed files:
- vms/ui/app_ui.py  +1 -1
- vms/ui/project/services/training.py  +180 -172
- vms/ui/project/tabs/preview_tab.py  +7 -11
- vms/ui/project/tabs/train_tab.py  +47 -11
vms/ui/app_ui.py
CHANGED

@@ -392,7 +392,7 @@ class AppUI:
        versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
        if versions:
            model_version_val = versions[0]
-
+
        # Ensure training_type is a valid display name
        training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
        if training_type_val not in TRAINING_TYPES:
vms/ui/project/services/training.py
CHANGED

@@ -14,6 +14,7 @@ import zipfile
import logging
import traceback
import threading
+import fcntl
import select

from typing import Any, Optional, Dict, List, Union, Tuple

@@ -63,6 +64,8 @@ class TrainingService:
        self.pid_file = OUTPUT_PATH / "training.pid"
        self.log_file = OUTPUT_PATH / "training.log"

+        self.file_lock = threading.Lock()
+
        self.file_handler = None
        self.setup_logging()
        self.ensure_valid_ui_state_file()

@@ -131,67 +134,69 @@ class TrainingService:
        """Save current UI state to file with validation"""
        ui_state_file = OUTPUT_PATH / "ui_state.json"

-       ... (previous save logic, without the lock; the removed lines are not rendered in this view)
+        # Use a lock to prevent concurrent writes
+        with self.file_lock:
+            # Validate values before saving
+            validated_values = {}
+            default_state = {
+                "model_type": list(MODEL_TYPES.keys())[0],
+                "model_version": "",
+                "training_type": list(TRAINING_TYPES.keys())[0],
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "training_preset": list(TRAINING_PRESETS.keys())[0],
+                "num_gpus": DEFAULT_NUM_GPUS,
+                "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
+                "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
+            }
+
+            # Copy default values first
+            validated_values = default_state.copy()
+
+            # Update with provided values, converting types as needed
+            for key, value in values.items():
+                if key in default_state:
+                    if key == "train_steps":
+                        try:
+                            validated_values[key] = int(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "batch_size":
+                        try:
+                            validated_values[key] = int(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "learning_rate":
+                        try:
+                            validated_values[key] = float(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "save_iterations":
+                        try:
+                            validated_values[key] = int(value)
+                        except (ValueError, TypeError):
+                            validated_values[key] = default_state[key]
+                    elif key == "lora_rank" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
                        validated_values[key] = default_state[key]
+                    elif key == "lora_alpha" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
                        validated_values[key] = default_state[key]
+                    else:
+                        validated_values[key] = value

+            try:
+                # First verify we can serialize to JSON
+                json_data = json.dumps(validated_values, indent=2)
+
+                # Write to the file
+                with open(ui_state_file, 'w') as f:
+                    f.write(json_data)
+                logger.debug(f"UI state saved successfully")
+            except Exception as e:
+                logger.error(f"Error saving UI state: {str(e)}")

    def _backup_and_recreate_ui_state(self, ui_state_file, default_state):
        """Backup the corrupted UI state file and create a new one with defaults"""

@@ -229,130 +234,133 @@ class TrainingService:
            "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
        }

-       ... (previous load/validation logic, without the lock; most removed lines are not rendered in this view)
+        # Use lock for reading too to avoid reading during a write
+        with self.file_lock:
+
+            if not ui_state_file.exists():
+                logger.info("UI state file does not exist, using default values")
                return default_state

+            try:
+                # First check if the file is empty
+                file_size = ui_state_file.stat().st_size
+                if file_size == 0:
+                    logger.warning("UI state file exists but is empty, using default values")
                    return default_state
+
+                with open(ui_state_file, 'r') as f:
+                    file_content = f.read().strip()
+                    if not file_content:
+                        logger.warning("UI state file is empty or contains only whitespace, using default values")
+                        return default_state

                try:
+                    saved_state = json.loads(file_content)
+                except json.JSONDecodeError as e:
+                    logger.error(f"Error parsing UI state JSON: {str(e)}")
+                    # Instead of showing the error, recreate the file with defaults
+                    self._backup_and_recreate_ui_state(ui_state_file, default_state)
+                    return default_state

+                # Clean up model type if it contains " (LoRA)" suffix
+                if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
+                    saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
+                    logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
+
+                # Convert numeric values to appropriate types
+                if "train_steps" in saved_state:
+                    try:
+                        saved_state["train_steps"] = int(saved_state["train_steps"])
+                    except (ValueError, TypeError):
+                        saved_state["train_steps"] = default_state["train_steps"]
+                        logger.warning("Invalid train_steps value, using default")

+                if "batch_size" in saved_state:
+                    try:
+                        saved_state["batch_size"] = int(saved_state["batch_size"])
+                    except (ValueError, TypeError):
+                        saved_state["batch_size"] = default_state["batch_size"]
+                        logger.warning("Invalid batch_size value, using default")
+
+                if "learning_rate" in saved_state:
+                    try:
+                        saved_state["learning_rate"] = float(saved_state["learning_rate"])
+                    except (ValueError, TypeError):
+                        saved_state["learning_rate"] = default_state["learning_rate"]
+                        logger.warning("Invalid learning_rate value, using default")
+
+                if "save_iterations" in saved_state:
+                    try:
+                        saved_state["save_iterations"] = int(saved_state["save_iterations"])
+                    except (ValueError, TypeError):
+                        saved_state["save_iterations"] = default_state["save_iterations"]
+                        logger.warning("Invalid save_iterations value, using default")
+
+                # Make sure we have all keys (in case structure changed)
+                merged_state = default_state.copy()
+                merged_state.update({k: v for k, v in saved_state.items() if v is not None})

+                # Validate model_type is in available choices
+                if merged_state["model_type"] not in MODEL_TYPES:
+                    # Try to map from internal name
+                    model_found = False
+                    for display_name, internal_name in MODEL_TYPES.items():
+                        if internal_name == merged_state["model_type"]:
+                            merged_state["model_type"] = display_name
+                            model_found = True
+                            break
+                    # If still not found, use default
+                    if not model_found:
+                        merged_state["model_type"] = default_state["model_type"]
+                        logger.warning(f"Invalid model type in saved state, using default")
+
+                # Validate model_version is appropriate for model_type
+                if "model_type" in merged_state and "model_version" in merged_state:
+                    model_internal_type = MODEL_TYPES.get(merged_state["model_type"])
+                    if model_internal_type:
+                        valid_versions = MODEL_VERSIONS.get(model_internal_type, {}).keys()
+                        if merged_state["model_version"] not in valid_versions:
+                            # Set to default for this model type
+                            from vms.ui.project.tabs.train_tab import TrainTab
+                            train_tab = TrainTab(None)  # Temporary instance just for the helper method
+                            merged_state["model_version"] = train_tab.get_default_model_version(saved_state["model_type"])
+                            logger.warning(f"Invalid model version for {merged_state['model_type']}, using default")

+                # Validate training_type is in available choices
+                if merged_state["training_type"] not in TRAINING_TYPES:
+                    # Try to map from internal name
+                    training_found = False
+                    for display_name, internal_name in TRAINING_TYPES.items():
+                        if internal_name == merged_state["training_type"]:
+                            merged_state["training_type"] = display_name
+                            training_found = True
+                            break
+                    # If still not found, use default
+                    if not training_found:
+                        merged_state["training_type"] = default_state["training_type"]
+                        logger.warning(f"Invalid training type in saved state, using default")

+                # Validate training_preset is in available choices
+                if merged_state["training_preset"] not in TRAINING_PRESETS:
+                    merged_state["training_preset"] = default_state["training_preset"]
+                    logger.warning(f"Invalid training preset in saved state, using default")
+
+                # Validate lora_rank is in allowed values
+                if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
+                    merged_state["lora_rank"] = default_state["lora_rank"]
+                    logger.warning(f"Invalid lora_rank in saved state, using default")
+
+                # Validate lora_alpha is in allowed values
+                if merged_state.get("lora_alpha") not in ["16", "32", "64", "128", "256", "512", "1024"]:
+                    merged_state["lora_alpha"] = default_state["lora_alpha"]
+                    logger.warning(f"Invalid lora_alpha in saved state, using default")
+
+                return merged_state
+            except Exception as e:
+                logger.error(f"Error loading UI state: {str(e)}")
+                # If anything goes wrong, backup and recreate
+                self._backup_and_recreate_ui_state(ui_state_file, default_state)
+                return default_state

    def ensure_valid_ui_state_file(self):
        """Ensure UI state file exists and is valid JSON"""
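Note on the training.py change: every read and write of ui_state.json now happens while a single threading.Lock is held, so a save triggered by one UI callback can no longer interleave with a load from another. Below is a minimal, self-contained sketch of that pattern; the StateStore class, file name, and default keys are illustrative stand-ins, not the actual VMS code.

import json
import threading
from pathlib import Path


class StateStore:
    """Illustrative lock-protected JSON state store (not the VMS implementation)."""

    def __init__(self, path: Path, defaults: dict):
        self.path = path
        self.defaults = defaults
        self.lock = threading.Lock()  # serializes concurrent saves and loads

    def save(self, values: dict) -> None:
        # Keep only known keys, merge them over the defaults, and write to disk
        # only if serialization succeeds.
        with self.lock:
            state = dict(self.defaults)
            state.update({k: v for k, v in values.items() if k in self.defaults})
            data = json.dumps(state, indent=2)
            self.path.write_text(data)

    def load(self) -> dict:
        # Fall back to the defaults on a missing, empty, or corrupted file.
        with self.lock:
            if not self.path.exists() or self.path.stat().st_size == 0:
                return dict(self.defaults)
            try:
                saved = json.loads(self.path.read_text())
            except json.JSONDecodeError:
                return dict(self.defaults)
            merged = dict(self.defaults)
            merged.update({k: v for k, v in saved.items() if v is not None})
            return merged


store = StateStore(Path("ui_state.json"), {"train_steps": 500, "batch_size": 1})
store.save({"train_steps": 250, "unknown_key": "ignored"})
print(store.load())  # {'train_steps': 250, 'batch_size': 1}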
vms/ui/project/tabs/preview_tab.py
CHANGED

@@ -298,7 +298,7 @@ class PreviewTab(BaseTab):
        # Update model_version choices when model_type changes or tab is selected
        if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
            self.app.tabs_component.select(
-               fn=self.…  (old handler name is truncated in this view)
+                fn=self.sync_model_type_and_versions,
                inputs=[],
                outputs=[
                    self.components["model_type"],

@@ -391,7 +391,7 @@ class PreviewTab(BaseTab):
            self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
        }

-   def …  (old method signature is truncated in this view)
+    def sync_model_type_and_versions(self) -> Tuple[str, str]:
        """Sync model type with training tab when preview tab is selected and update model version choices"""
        model_type = self.get_default_model_type()
        model_version = ""

@@ -401,19 +401,15 @@ class PreviewTab(BaseTab):
            preview_state = ui_state.get("preview", {})
            model_version = preview_state.get("model_version", "")

+        # If no model version specified or invalid, use default
        if not model_version:
-           # …  (removed comment is truncated in this view)
+            # Get the internal model type
            internal_type = MODEL_TYPES.get(model_type)
            if internal_type and internal_type in MODEL_VERSIONS:
-               ...  (removed lines are only partially rendered in this view)
-               model_version = f"{first_version} - {model_version_info.get('name', '')}"
+                versions = list(MODEL_VERSIONS[internal_type].keys())
+                if versions:
+                    model_version = versions[0]

-       # If we couldn't get it, use default
-       if not model_version:
-           model_version = self.get_default_model_version(model_type)
-
        return model_type, model_version

    def update_resolution(self, preset: str) -> Tuple[int, int, float]:
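Note on the preview_tab.py change: the fallback now takes the first entry of MODEL_VERSIONS for the selected model type. Python dicts preserve insertion order (guaranteed since 3.7), so "first key" effectively means "first version declared for that model". A tiny sketch of that fallback, using a made-up mapping rather than the real MODEL_VERSIONS:

# Hypothetical stand-in for the real MODEL_VERSIONS mapping
MODEL_VERSIONS = {
    "some_model": {
        "org/some-model-1.0": {"name": "Some Model 1.0"},
        "org/some-model-1.5": {"name": "Some Model 1.5"},
    },
}


def default_version(internal_type: str) -> str:
    # The first declared version wins; empty string if the type is unknown.
    versions = MODEL_VERSIONS.get(internal_type, {})
    return next(iter(versions), "")


print(default_version("some_model"))  # org/some-model-1.0
print(default_version("unknown"))     # prints an empty line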
vms/ui/project/tabs/train_tab.py
CHANGED

@@ -69,7 +69,7 @@ class TrainTab(BaseTab):
        # Get model versions for the default model type
        default_model_versions = self.get_model_version_choices(default_model_type)
        default_model_version = self.get_default_model_version(default_model_type)
-
+        print(f"default_model_version(default_model_type) = {default_model_version}")
        self.components["model_version"] = gr.Dropdown(
            choices=default_model_versions,
            label="Model Version",

@@ -214,6 +214,37 @@ class TrainTab(BaseTab):

        return tab

+    def update_model_type_and_version(self, model_type: str, model_version: str):
+        """Update both model type and version together to keep them in sync"""
+        # Get internal model type
+        internal_type = MODEL_TYPES.get(model_type)
+
+        # Make sure model_version is valid for this model type
+        if internal_type and internal_type in MODEL_VERSIONS:
+            valid_versions = list(MODEL_VERSIONS[internal_type].keys())
+            if not model_version or model_version not in valid_versions:
+                if valid_versions:
+                    model_version = valid_versions[0]
+
+        # Update UI state with both values to keep them in sync
+        self.app.update_ui_state(model_type=model_type, model_version=model_version)
+        return None
+
+    def save_model_version(self, model_type: str, model_version: str):
+        """Save model version ensuring it's consistent with model type"""
+        internal_type = MODEL_TYPES.get(model_type)
+
+        # Verify the model_version is compatible with the current model_type
+        if internal_type and internal_type in MODEL_VERSIONS:
+            valid_versions = MODEL_VERSIONS[internal_type].keys()
+            if model_version not in valid_versions:
+                # Don't save incompatible version
+                return None
+
+        # Save the model version along with current model type to ensure consistency
+        self.app.update_ui_state(model_type=model_type, model_version=model_version)
+        return None
+
    def connect_events(self) -> None:
        """Connect event handlers to UI components"""
        # Model type change event - Update model version dropdown choices

@@ -222,8 +253,8 @@ class TrainTab(BaseTab):
            inputs=[self.components["model_type"]],
            outputs=[self.components["model_version"]]
        ).then(
-           fn=…  (old handler is truncated in this view)
-           inputs=[self.components["model_type"]],
+            fn=self.update_model_type_and_version,  # Add this new function
+            inputs=[self.components["model_type"], self.components["model_version"]],
            outputs=[]
        ).then(
            # Use get_model_info instead of update_model_info

@@ -234,8 +265,8 @@ class TrainTab(BaseTab):

        # Model version change event
        self.components["model_version"].change(
-           fn=…  (old handler is truncated in this view)
-           inputs=[self.components["model_version"]],
+            fn=self.save_model_version,  # Replace with this new function
+            inputs=[self.components["model_type"], self.components["model_version"]],
            outputs=[]
        )

@@ -399,10 +430,13 @@ class TrainTab(BaseTab):
        """Update model version choices based on selected model type"""
        model_versions = self.get_model_version_choices(model_type)
        default_version = self.get_default_model_version(model_type)
+        print(f"update_model_versions({model_type}): default_version = {default_version}")
+        # Update UI state with proper model_type first (add this line)
+        self.app.update_ui_state(model_type=model_type)

        # Update the model_version dropdown with new choices and default value
        return gr.Dropdown(choices=model_versions, value=default_version)
-
+
    def handle_training_start(
        self, preset, model_type, model_version, training_type,
        lora_rank, lora_alpha, train_steps, batch_size, learning_rate,

@@ -477,22 +511,24 @@ class TrainTab(BaseTab):
        if not internal_type or internal_type not in MODEL_VERSIONS:
            return []

-       # …  (previous return logic; the removed lines are not rendered in this view)
+        # Return just the model IDs without formatting
+        return list(MODEL_VERSIONS.get(internal_type, {}).keys())
+

    def get_default_model_version(self, model_type: str) -> str:
        """Get default model version for the given model type"""
        # Convert UI display name to internal name
        internal_type = MODEL_TYPES.get(model_type)
+        print(f"get_default_model_version({model_type}) = {internal_type}")
        if not internal_type or internal_type not in MODEL_VERSIONS:
            return ""

        # Get the first version available for this model type
        versions = MODEL_VERSIONS.get(internal_type, {})
        if versions:
-           ...  (removed lines are not rendered in this view)
+            model_versions = list(versions.keys())
+            if model_versions:
+                return model_versions[0]
        return ""

    def update_model_info(self, model_type: str, training_type: str) -> Dict:
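Note on the train_tab.py event wiring: Gradio lets handlers be chained with .change(...).then(...), so the version dropdown is repopulated first and the combined model type plus version is saved afterwards. A stripped-down sketch of that wiring follows; it assumes a recent Gradio release where returning gr.Dropdown(...) from a handler updates the component, and all names and choices here are illustrative, not the VMS code.

import gradio as gr

# Illustrative mapping of model types to available versions
MODEL_VERSIONS = {
    "Model A": ["a-1.0", "a-2.0"],
    "Model B": ["b-0.9"],
}


def update_versions(model_type):
    # Repopulate the version dropdown and select the first valid entry
    versions = MODEL_VERSIONS.get(model_type, [])
    return gr.Dropdown(choices=versions, value=versions[0] if versions else None)


def save_state(model_type, model_version):
    # Placeholder for persisting both values together (e.g. to a ui_state.json)
    print(f"saving model_type={model_type!r} model_version={model_version!r}")


with gr.Blocks() as demo:
    model_type = gr.Dropdown(choices=list(MODEL_VERSIONS), value="Model A", label="Model Type")
    model_version = gr.Dropdown(choices=MODEL_VERSIONS["Model A"], value="a-1.0", label="Model Version")

    # Chain: update the version choices first, then persist the pair
    model_type.change(
        fn=update_versions, inputs=[model_type], outputs=[model_version]
    ).then(
        fn=save_state, inputs=[model_type, model_version], outputs=[]
    )
    model_version.change(fn=save_state, inputs=[model_type, model_version], outputs=[])

if __name__ == "__main__":
    demo.launch()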