jbilcke-hf HF Staff commited on
Commit
61a25f0
·
1 Parent(s): d2662cc
vms/ui/app_ui.py CHANGED
@@ -392,7 +392,7 @@ class AppUI:
392
  versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
393
  if versions:
394
  model_version_val = versions[0]
395
-
396
  # Ensure training_type is a valid display name
397
  training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
398
  if training_type_val not in TRAINING_TYPES:
 
392
  versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
393
  if versions:
394
  model_version_val = versions[0]
395
+
396
  # Ensure training_type is a valid display name
397
  training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
398
  if training_type_val not in TRAINING_TYPES:
vms/ui/project/services/training.py CHANGED
@@ -14,6 +14,7 @@ import zipfile
14
  import logging
15
  import traceback
16
  import threading
 
17
  import select
18
 
19
  from typing import Any, Optional, Dict, List, Union, Tuple
@@ -63,6 +64,8 @@ class TrainingService:
63
  self.pid_file = OUTPUT_PATH / "training.pid"
64
  self.log_file = OUTPUT_PATH / "training.log"
65
 
 
 
66
  self.file_handler = None
67
  self.setup_logging()
68
  self.ensure_valid_ui_state_file()
@@ -131,67 +134,69 @@ class TrainingService:
131
  """Save current UI state to file with validation"""
132
  ui_state_file = OUTPUT_PATH / "ui_state.json"
133
 
134
- # Validate values before saving
135
- validated_values = {}
136
- default_state = {
137
- "model_type": list(MODEL_TYPES.keys())[0],
138
- "model_version": "",
139
- "training_type": list(TRAINING_TYPES.keys())[0],
140
- "lora_rank": DEFAULT_LORA_RANK_STR,
141
- "lora_alpha": DEFAULT_LORA_ALPHA_STR,
142
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
143
- "batch_size": DEFAULT_BATCH_SIZE,
144
- "learning_rate": DEFAULT_LEARNING_RATE,
145
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
146
- "training_preset": list(TRAINING_PRESETS.keys())[0],
147
- "num_gpus": DEFAULT_NUM_GPUS,
148
- "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
149
- "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
150
- }
151
-
152
- # Copy default values first
153
- validated_values = default_state.copy()
154
-
155
- # Update with provided values, converting types as needed
156
- for key, value in values.items():
157
- if key in default_state:
158
- if key == "train_steps":
159
- try:
160
- validated_values[key] = int(value)
161
- except (ValueError, TypeError):
162
- validated_values[key] = default_state[key]
163
- elif key == "batch_size":
164
- try:
165
- validated_values[key] = int(value)
166
- except (ValueError, TypeError):
167
- validated_values[key] = default_state[key]
168
- elif key == "learning_rate":
169
- try:
170
- validated_values[key] = float(value)
171
- except (ValueError, TypeError):
 
 
 
 
 
 
 
 
 
172
  validated_values[key] = default_state[key]
173
- elif key == "save_iterations":
174
- try:
175
- validated_values[key] = int(value)
176
- except (ValueError, TypeError):
177
  validated_values[key] = default_state[key]
178
- elif key == "lora_rank" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
179
- validated_values[key] = default_state[key]
180
- elif key == "lora_alpha" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
181
- validated_values[key] = default_state[key]
182
- else:
183
- validated_values[key] = value
184
-
185
- try:
186
- # First verify we can serialize to JSON
187
- json_data = json.dumps(validated_values, indent=2)
188
 
189
- # Write to the file
190
- with open(ui_state_file, 'w') as f:
191
- f.write(json_data)
192
- logger.debug(f"UI state saved successfully")
193
- except Exception as e:
194
- logger.error(f"Error saving UI state: {str(e)}")
 
 
 
 
195
 
196
  def _backup_and_recreate_ui_state(self, ui_state_file, default_state):
197
  """Backup the corrupted UI state file and create a new one with defaults"""
@@ -229,130 +234,133 @@ class TrainingService:
229
  "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
230
  }
231
 
232
- if not ui_state_file.exists():
233
- logger.info("UI state file does not exist, using default values")
234
- return default_state
235
-
236
- try:
237
- # First check if the file is empty
238
- file_size = ui_state_file.stat().st_size
239
- if file_size == 0:
240
- logger.warning("UI state file exists but is empty, using default values")
241
  return default_state
242
-
243
- with open(ui_state_file, 'r') as f:
244
- file_content = f.read().strip()
245
- if not file_content:
246
- logger.warning("UI state file is empty or contains only whitespace, using default values")
247
- return default_state
248
 
249
- try:
250
- saved_state = json.loads(file_content)
251
- except json.JSONDecodeError as e:
252
- logger.error(f"Error parsing UI state JSON: {str(e)}")
253
- # Instead of showing the error, recreate the file with defaults
254
- self._backup_and_recreate_ui_state(ui_state_file, default_state)
255
  return default_state
256
-
257
- # Clean up model type if it contains " (LoRA)" suffix
258
- if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
259
- saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
260
- logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
261
-
262
- # Convert numeric values to appropriate types
263
- if "train_steps" in saved_state:
264
- try:
265
- saved_state["train_steps"] = int(saved_state["train_steps"])
266
- except (ValueError, TypeError):
267
- saved_state["train_steps"] = default_state["train_steps"]
268
- logger.warning("Invalid train_steps value, using default")
269
-
270
- if "batch_size" in saved_state:
271
- try:
272
- saved_state["batch_size"] = int(saved_state["batch_size"])
273
- except (ValueError, TypeError):
274
- saved_state["batch_size"] = default_state["batch_size"]
275
- logger.warning("Invalid batch_size value, using default")
276
-
277
- if "learning_rate" in saved_state:
278
- try:
279
- saved_state["learning_rate"] = float(saved_state["learning_rate"])
280
- except (ValueError, TypeError):
281
- saved_state["learning_rate"] = default_state["learning_rate"]
282
- logger.warning("Invalid learning_rate value, using default")
283
 
284
- if "save_iterations" in saved_state:
285
  try:
286
- saved_state["save_iterations"] = int(saved_state["save_iterations"])
287
- except (ValueError, TypeError):
288
- saved_state["save_iterations"] = default_state["save_iterations"]
289
- logger.warning("Invalid save_iterations value, using default")
 
 
290
 
291
- # Make sure we have all keys (in case structure changed)
292
- merged_state = default_state.copy()
293
- merged_state.update({k: v for k, v in saved_state.items() if v is not None})
294
-
295
- # Validate model_type is in available choices
296
- if merged_state["model_type"] not in MODEL_TYPES:
297
- # Try to map from internal name
298
- model_found = False
299
- for display_name, internal_name in MODEL_TYPES.items():
300
- if internal_name == merged_state["model_type"]:
301
- merged_state["model_type"] = display_name
302
- model_found = True
303
- break
304
- # If still not found, use default
305
- if not model_found:
306
- merged_state["model_type"] = default_state["model_type"]
307
- logger.warning(f"Invalid model type in saved state, using default")
308
 
309
- # Validate model_version is appropriate for model_type
310
- if "model_type" in merged_state and "model_version" in merged_state:
311
- model_internal_type = MODEL_TYPES.get(merged_state["model_type"])
312
- if model_internal_type:
313
- valid_versions = MODEL_VERSIONS.get(model_internal_type, {}).keys()
314
- if merged_state["model_version"] not in valid_versions:
315
- # Set to default for this model type
316
- from vms.ui.project.tabs.train_tab import TrainTab
317
- train_tab = TrainTab(None) # Temporary instance just for the helper method
318
- merged_state["model_version"] = train_tab.get_default_model_version(saved_state["model_type"])
319
- logger.warning(f"Invalid model version for {merged_state['model_type']}, using default")
320
-
321
- # Validate training_type is in available choices
322
- if merged_state["training_type"] not in TRAINING_TYPES:
323
- # Try to map from internal name
324
- training_found = False
325
- for display_name, internal_name in TRAINING_TYPES.items():
326
- if internal_name == merged_state["training_type"]:
327
- merged_state["training_type"] = display_name
328
- training_found = True
329
- break
330
- # If still not found, use default
331
- if not training_found:
332
- merged_state["training_type"] = default_state["training_type"]
333
- logger.warning(f"Invalid training type in saved state, using default")
334
-
335
- # Validate training_preset is in available choices
336
- if merged_state["training_preset"] not in TRAINING_PRESETS:
337
- merged_state["training_preset"] = default_state["training_preset"]
338
- logger.warning(f"Invalid training preset in saved state, using default")
339
 
340
- # Validate lora_rank is in allowed values
341
- if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
342
- merged_state["lora_rank"] = default_state["lora_rank"]
343
- logger.warning(f"Invalid lora_rank in saved state, using default")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
- # Validate lora_alpha is in allowed values
346
- if merged_state.get("lora_alpha") not in ["16", "32", "64", "128", "256", "512", "1024"]:
347
- merged_state["lora_alpha"] = default_state["lora_alpha"]
348
- logger.warning(f"Invalid lora_alpha in saved state, using default")
 
 
 
 
 
 
 
 
 
349
 
350
- return merged_state
351
- except Exception as e:
352
- logger.error(f"Error loading UI state: {str(e)}")
353
- # If anything goes wrong, backup and recreate
354
- self._backup_and_recreate_ui_state(ui_state_file, default_state)
355
- return default_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
  def ensure_valid_ui_state_file(self):
358
  """Ensure UI state file exists and is valid JSON"""
 
14
  import logging
15
  import traceback
16
  import threading
17
+ import fcntl
18
  import select
19
 
20
  from typing import Any, Optional, Dict, List, Union, Tuple
 
64
  self.pid_file = OUTPUT_PATH / "training.pid"
65
  self.log_file = OUTPUT_PATH / "training.log"
66
 
67
+ self.file_lock = threading.Lock()
68
+
69
  self.file_handler = None
70
  self.setup_logging()
71
  self.ensure_valid_ui_state_file()
 
134
  """Save current UI state to file with validation"""
135
  ui_state_file = OUTPUT_PATH / "ui_state.json"
136
 
137
+ # Use a lock to prevent concurrent writes
138
+ with self.file_lock:
139
+ # Validate values before saving
140
+ validated_values = {}
141
+ default_state = {
142
+ "model_type": list(MODEL_TYPES.keys())[0],
143
+ "model_version": "",
144
+ "training_type": list(TRAINING_TYPES.keys())[0],
145
+ "lora_rank": DEFAULT_LORA_RANK_STR,
146
+ "lora_alpha": DEFAULT_LORA_ALPHA_STR,
147
+ "train_steps": DEFAULT_NB_TRAINING_STEPS,
148
+ "batch_size": DEFAULT_BATCH_SIZE,
149
+ "learning_rate": DEFAULT_LEARNING_RATE,
150
+ "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
151
+ "training_preset": list(TRAINING_PRESETS.keys())[0],
152
+ "num_gpus": DEFAULT_NUM_GPUS,
153
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
154
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
155
+ }
156
+
157
+ # Copy default values first
158
+ validated_values = default_state.copy()
159
+
160
+ # Update with provided values, converting types as needed
161
+ for key, value in values.items():
162
+ if key in default_state:
163
+ if key == "train_steps":
164
+ try:
165
+ validated_values[key] = int(value)
166
+ except (ValueError, TypeError):
167
+ validated_values[key] = default_state[key]
168
+ elif key == "batch_size":
169
+ try:
170
+ validated_values[key] = int(value)
171
+ except (ValueError, TypeError):
172
+ validated_values[key] = default_state[key]
173
+ elif key == "learning_rate":
174
+ try:
175
+ validated_values[key] = float(value)
176
+ except (ValueError, TypeError):
177
+ validated_values[key] = default_state[key]
178
+ elif key == "save_iterations":
179
+ try:
180
+ validated_values[key] = int(value)
181
+ except (ValueError, TypeError):
182
+ validated_values[key] = default_state[key]
183
+ elif key == "lora_rank" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
184
  validated_values[key] = default_state[key]
185
+ elif key == "lora_alpha" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
 
 
 
186
  validated_values[key] = default_state[key]
187
+ else:
188
+ validated_values[key] = value
 
 
 
 
 
 
 
 
189
 
190
+ try:
191
+ # First verify we can serialize to JSON
192
+ json_data = json.dumps(validated_values, indent=2)
193
+
194
+ # Write to the file
195
+ with open(ui_state_file, 'w') as f:
196
+ f.write(json_data)
197
+ logger.debug(f"UI state saved successfully")
198
+ except Exception as e:
199
+ logger.error(f"Error saving UI state: {str(e)}")
200
 
201
  def _backup_and_recreate_ui_state(self, ui_state_file, default_state):
202
  """Backup the corrupted UI state file and create a new one with defaults"""
 
234
  "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
235
  }
236
 
237
+ # Use lock for reading too to avoid reading during a write
238
+ with self.file_lock:
239
+
240
+ if not ui_state_file.exists():
241
+ logger.info("UI state file does not exist, using default values")
 
 
 
 
242
  return default_state
 
 
 
 
 
 
243
 
244
+ try:
245
+ # First check if the file is empty
246
+ file_size = ui_state_file.stat().st_size
247
+ if file_size == 0:
248
+ logger.warning("UI state file exists but is empty, using default values")
 
249
  return default_state
250
+
251
+ with open(ui_state_file, 'r') as f:
252
+ file_content = f.read().strip()
253
+ if not file_content:
254
+ logger.warning("UI state file is empty or contains only whitespace, using default values")
255
+ return default_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
 
257
  try:
258
+ saved_state = json.loads(file_content)
259
+ except json.JSONDecodeError as e:
260
+ logger.error(f"Error parsing UI state JSON: {str(e)}")
261
+ # Instead of showing the error, recreate the file with defaults
262
+ self._backup_and_recreate_ui_state(ui_state_file, default_state)
263
+ return default_state
264
 
265
+ # Clean up model type if it contains " (LoRA)" suffix
266
+ if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
267
+ saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
268
+ logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
269
+
270
+ # Convert numeric values to appropriate types
271
+ if "train_steps" in saved_state:
272
+ try:
273
+ saved_state["train_steps"] = int(saved_state["train_steps"])
274
+ except (ValueError, TypeError):
275
+ saved_state["train_steps"] = default_state["train_steps"]
276
+ logger.warning("Invalid train_steps value, using default")
 
 
 
 
 
277
 
278
+ if "batch_size" in saved_state:
279
+ try:
280
+ saved_state["batch_size"] = int(saved_state["batch_size"])
281
+ except (ValueError, TypeError):
282
+ saved_state["batch_size"] = default_state["batch_size"]
283
+ logger.warning("Invalid batch_size value, using default")
284
+
285
+ if "learning_rate" in saved_state:
286
+ try:
287
+ saved_state["learning_rate"] = float(saved_state["learning_rate"])
288
+ except (ValueError, TypeError):
289
+ saved_state["learning_rate"] = default_state["learning_rate"]
290
+ logger.warning("Invalid learning_rate value, using default")
291
+
292
+ if "save_iterations" in saved_state:
293
+ try:
294
+ saved_state["save_iterations"] = int(saved_state["save_iterations"])
295
+ except (ValueError, TypeError):
296
+ saved_state["save_iterations"] = default_state["save_iterations"]
297
+ logger.warning("Invalid save_iterations value, using default")
298
+
299
+ # Make sure we have all keys (in case structure changed)
300
+ merged_state = default_state.copy()
301
+ merged_state.update({k: v for k, v in saved_state.items() if v is not None})
 
 
 
 
 
 
302
 
303
+ # Validate model_type is in available choices
304
+ if merged_state["model_type"] not in MODEL_TYPES:
305
+ # Try to map from internal name
306
+ model_found = False
307
+ for display_name, internal_name in MODEL_TYPES.items():
308
+ if internal_name == merged_state["model_type"]:
309
+ merged_state["model_type"] = display_name
310
+ model_found = True
311
+ break
312
+ # If still not found, use default
313
+ if not model_found:
314
+ merged_state["model_type"] = default_state["model_type"]
315
+ logger.warning(f"Invalid model type in saved state, using default")
316
+
317
+ # Validate model_version is appropriate for model_type
318
+ if "model_type" in merged_state and "model_version" in merged_state:
319
+ model_internal_type = MODEL_TYPES.get(merged_state["model_type"])
320
+ if model_internal_type:
321
+ valid_versions = MODEL_VERSIONS.get(model_internal_type, {}).keys()
322
+ if merged_state["model_version"] not in valid_versions:
323
+ # Set to default for this model type
324
+ from vms.ui.project.tabs.train_tab import TrainTab
325
+ train_tab = TrainTab(None) # Temporary instance just for the helper method
326
+ merged_state["model_version"] = train_tab.get_default_model_version(saved_state["model_type"])
327
+ logger.warning(f"Invalid model version for {merged_state['model_type']}, using default")
328
 
329
+ # Validate training_type is in available choices
330
+ if merged_state["training_type"] not in TRAINING_TYPES:
331
+ # Try to map from internal name
332
+ training_found = False
333
+ for display_name, internal_name in TRAINING_TYPES.items():
334
+ if internal_name == merged_state["training_type"]:
335
+ merged_state["training_type"] = display_name
336
+ training_found = True
337
+ break
338
+ # If still not found, use default
339
+ if not training_found:
340
+ merged_state["training_type"] = default_state["training_type"]
341
+ logger.warning(f"Invalid training type in saved state, using default")
342
 
343
+ # Validate training_preset is in available choices
344
+ if merged_state["training_preset"] not in TRAINING_PRESETS:
345
+ merged_state["training_preset"] = default_state["training_preset"]
346
+ logger.warning(f"Invalid training preset in saved state, using default")
347
+
348
+ # Validate lora_rank is in allowed values
349
+ if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
350
+ merged_state["lora_rank"] = default_state["lora_rank"]
351
+ logger.warning(f"Invalid lora_rank in saved state, using default")
352
+
353
+ # Validate lora_alpha is in allowed values
354
+ if merged_state.get("lora_alpha") not in ["16", "32", "64", "128", "256", "512", "1024"]:
355
+ merged_state["lora_alpha"] = default_state["lora_alpha"]
356
+ logger.warning(f"Invalid lora_alpha in saved state, using default")
357
+
358
+ return merged_state
359
+ except Exception as e:
360
+ logger.error(f"Error loading UI state: {str(e)}")
361
+ # If anything goes wrong, backup and recreate
362
+ self._backup_and_recreate_ui_state(ui_state_file, default_state)
363
+ return default_state
364
 
365
  def ensure_valid_ui_state_file(self):
366
  """Ensure UI state file exists and is valid JSON"""
vms/ui/project/tabs/preview_tab.py CHANGED
@@ -298,7 +298,7 @@ class PreviewTab(BaseTab):
298
  # Update model_version choices when model_type changes or tab is selected
299
  if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
300
  self.app.tabs_component.select(
301
- fn=self.sync_model_type_and_verions,
302
  inputs=[],
303
  outputs=[
304
  self.components["model_type"],
@@ -391,7 +391,7 @@ class PreviewTab(BaseTab):
391
  self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
392
  }
393
 
394
- def sync_model_type_and_verions(self) -> Tuple[str, str]:
395
  """Sync model type with training tab when preview tab is selected and update model version choices"""
396
  model_type = self.get_default_model_type()
397
  model_version = ""
@@ -401,19 +401,15 @@ class PreviewTab(BaseTab):
401
  preview_state = ui_state.get("preview", {})
402
  model_version = preview_state.get("model_version", "")
403
 
 
404
  if not model_version:
405
- # Format it as a display choice
406
  internal_type = MODEL_TYPES.get(model_type)
407
  if internal_type and internal_type in MODEL_VERSIONS:
408
- first_version = next(iter(MODEL_VERSIONS[internal_type].keys()), "")
409
- if first_version:
410
- model_version_info = MODEL_VERSIONS[internal_type][first_version]
411
- model_version = f"{first_version} - {model_version_info.get('name', '')}"
412
 
413
- # If we couldn't get it, use default
414
- if not model_version:
415
- model_version = self.get_default_model_version(model_type)
416
-
417
  return model_type, model_version
418
 
419
  def update_resolution(self, preset: str) -> Tuple[int, int, float]:
 
298
  # Update model_version choices when model_type changes or tab is selected
299
  if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
300
  self.app.tabs_component.select(
301
+ fn=self.sync_model_type_and_versions,
302
  inputs=[],
303
  outputs=[
304
  self.components["model_type"],
 
391
  self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
392
  }
393
 
394
+ def sync_model_type_and_versions(self) -> Tuple[str, str]:
395
  """Sync model type with training tab when preview tab is selected and update model version choices"""
396
  model_type = self.get_default_model_type()
397
  model_version = ""
 
401
  preview_state = ui_state.get("preview", {})
402
  model_version = preview_state.get("model_version", "")
403
 
404
+ # If no model version specified or invalid, use default
405
  if not model_version:
406
+ # Get the internal model type
407
  internal_type = MODEL_TYPES.get(model_type)
408
  if internal_type and internal_type in MODEL_VERSIONS:
409
+ versions = list(MODEL_VERSIONS[internal_type].keys())
410
+ if versions:
411
+ model_version = versions[0]
 
412
 
 
 
 
 
413
  return model_type, model_version
414
 
415
  def update_resolution(self, preset: str) -> Tuple[int, int, float]:
vms/ui/project/tabs/train_tab.py CHANGED
@@ -69,7 +69,7 @@ class TrainTab(BaseTab):
69
  # Get model versions for the default model type
70
  default_model_versions = self.get_model_version_choices(default_model_type)
71
  default_model_version = self.get_default_model_version(default_model_type)
72
-
73
  self.components["model_version"] = gr.Dropdown(
74
  choices=default_model_versions,
75
  label="Model Version",
@@ -214,6 +214,37 @@ class TrainTab(BaseTab):
214
 
215
  return tab
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  def connect_events(self) -> None:
218
  """Connect event handlers to UI components"""
219
  # Model type change event - Update model version dropdown choices
@@ -222,8 +253,8 @@ class TrainTab(BaseTab):
222
  inputs=[self.components["model_type"]],
223
  outputs=[self.components["model_version"]]
224
  ).then(
225
- fn=lambda v: self.app.update_ui_state(model_type=v),
226
- inputs=[self.components["model_type"]],
227
  outputs=[]
228
  ).then(
229
  # Use get_model_info instead of update_model_info
@@ -234,8 +265,8 @@ class TrainTab(BaseTab):
234
 
235
  # Model version change event
236
  self.components["model_version"].change(
237
- fn=lambda v: self.app.update_ui_state(model_version=v),
238
- inputs=[self.components["model_version"]],
239
  outputs=[]
240
  )
241
 
@@ -399,10 +430,13 @@ class TrainTab(BaseTab):
399
  """Update model version choices based on selected model type"""
400
  model_versions = self.get_model_version_choices(model_type)
401
  default_version = self.get_default_model_version(model_type)
 
 
 
402
 
403
  # Update the model_version dropdown with new choices and default value
404
  return gr.Dropdown(choices=model_versions, value=default_version)
405
-
406
  def handle_training_start(
407
  self, preset, model_type, model_version, training_type,
408
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
@@ -477,22 +511,24 @@ class TrainTab(BaseTab):
477
  if not internal_type or internal_type not in MODEL_VERSIONS:
478
  return []
479
 
480
- # Get versions and return them as choices
481
- versions = MODEL_VERSIONS.get(internal_type, {})
482
- return list(versions.keys())
483
 
484
  def get_default_model_version(self, model_type: str) -> str:
485
  """Get default model version for the given model type"""
486
  # Convert UI display name to internal name
487
  internal_type = MODEL_TYPES.get(model_type)
 
488
  if not internal_type or internal_type not in MODEL_VERSIONS:
489
  return ""
490
 
491
  # Get the first version available for this model type
492
  versions = MODEL_VERSIONS.get(internal_type, {})
493
  if versions:
494
- return next(iter(versions.keys()))
495
-
 
496
  return ""
497
 
498
  def update_model_info(self, model_type: str, training_type: str) -> Dict:
 
69
  # Get model versions for the default model type
70
  default_model_versions = self.get_model_version_choices(default_model_type)
71
  default_model_version = self.get_default_model_version(default_model_type)
72
+ print(f"default_model_version(default_model_type) = {default_model_version}")
73
  self.components["model_version"] = gr.Dropdown(
74
  choices=default_model_versions,
75
  label="Model Version",
 
214
 
215
  return tab
216
 
217
+ def update_model_type_and_version(self, model_type: str, model_version: str):
218
+ """Update both model type and version together to keep them in sync"""
219
+ # Get internal model type
220
+ internal_type = MODEL_TYPES.get(model_type)
221
+
222
+ # Make sure model_version is valid for this model type
223
+ if internal_type and internal_type in MODEL_VERSIONS:
224
+ valid_versions = list(MODEL_VERSIONS[internal_type].keys())
225
+ if not model_version or model_version not in valid_versions:
226
+ if valid_versions:
227
+ model_version = valid_versions[0]
228
+
229
+ # Update UI state with both values to keep them in sync
230
+ self.app.update_ui_state(model_type=model_type, model_version=model_version)
231
+ return None
232
+
233
+ def save_model_version(self, model_type: str, model_version: str):
234
+ """Save model version ensuring it's consistent with model type"""
235
+ internal_type = MODEL_TYPES.get(model_type)
236
+
237
+ # Verify the model_version is compatible with the current model_type
238
+ if internal_type and internal_type in MODEL_VERSIONS:
239
+ valid_versions = MODEL_VERSIONS[internal_type].keys()
240
+ if model_version not in valid_versions:
241
+ # Don't save incompatible version
242
+ return None
243
+
244
+ # Save the model version along with current model type to ensure consistency
245
+ self.app.update_ui_state(model_type=model_type, model_version=model_version)
246
+ return None
247
+
248
  def connect_events(self) -> None:
249
  """Connect event handlers to UI components"""
250
  # Model type change event - Update model version dropdown choices
 
253
  inputs=[self.components["model_type"]],
254
  outputs=[self.components["model_version"]]
255
  ).then(
256
+ fn=self.update_model_type_and_version, # Add this new function
257
+ inputs=[self.components["model_type"], self.components["model_version"]],
258
  outputs=[]
259
  ).then(
260
  # Use get_model_info instead of update_model_info
 
265
 
266
  # Model version change event
267
  self.components["model_version"].change(
268
+ fn=self.save_model_version, # Replace with this new function
269
+ inputs=[self.components["model_type"], self.components["model_version"]],
270
  outputs=[]
271
  )
272
 
 
430
  """Update model version choices based on selected model type"""
431
  model_versions = self.get_model_version_choices(model_type)
432
  default_version = self.get_default_model_version(model_type)
433
+ print(f"update_model_versions({model_type}): default_version = {default_version}")
434
+ # Update UI state with proper model_type first (add this line)
435
+ self.app.update_ui_state(model_type=model_type)
436
 
437
  # Update the model_version dropdown with new choices and default value
438
  return gr.Dropdown(choices=model_versions, value=default_version)
439
+
440
  def handle_training_start(
441
  self, preset, model_type, model_version, training_type,
442
  lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
 
511
  if not internal_type or internal_type not in MODEL_VERSIONS:
512
  return []
513
 
514
+ # Return just the model IDs without formatting
515
+ return list(MODEL_VERSIONS.get(internal_type, {}).keys())
516
+
517
 
518
  def get_default_model_version(self, model_type: str) -> str:
519
  """Get default model version for the given model type"""
520
  # Convert UI display name to internal name
521
  internal_type = MODEL_TYPES.get(model_type)
522
+ print(f"get_default_model_version({model_type}) = {internal_type}")
523
  if not internal_type or internal_type not in MODEL_VERSIONS:
524
  return ""
525
 
526
  # Get the first version available for this model type
527
  versions = MODEL_VERSIONS.get(internal_type, {})
528
  if versions:
529
+ model_versions = list(versions.keys())
530
+ if model_versions:
531
+ return model_versions[0]
532
  return ""
533
 
534
  def update_model_info(self, model_type: str, training_type: str) -> Dict: