jbilcke-hf HF Staff commited on
Commit
743eda6
·
1 Parent(s): 446e79f

improve UI persistence

Browse files
Files changed (2) hide show
  1. app.py +21 -8
  2. vms/training_service.py +14 -3
app.py CHANGED
@@ -144,7 +144,9 @@ class VideoTrainerUI:
144
  """Load UI state values for initializing form fields"""
145
  ui_state = self.trainer.load_ui_state()
146
 
147
- # Convert types as needed since JSON stores everything as strings
 
 
148
  ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
149
  ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
150
  ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
@@ -866,9 +868,12 @@ class VideoTrainerUI:
866
  )
867
 
868
  def update_training_params(self, preset_name: str) -> Tuple:
869
- """Update UI components based on selected preset"""
870
  preset = TRAINING_PRESETS[preset_name]
871
 
 
 
 
872
  # Find the display name that maps to our model type
873
  model_display_name = next(
874
  key for key, value in MODEL_TYPES.items()
@@ -888,14 +893,22 @@ class VideoTrainerUI:
888
  info_text = f"{description}{bucket_info}"
889
 
890
  # Return values in the same order as the output components
 
 
 
 
 
 
 
 
891
  return (
892
  model_display_name,
893
- preset["lora_rank"],
894
- preset["lora_alpha"],
895
- preset["num_epochs"],
896
- preset["batch_size"],
897
- preset["learning_rate"],
898
- preset["save_iterations"],
899
  info_text
900
  )
901
 
 
144
  """Load UI state values for initializing form fields"""
145
  ui_state = self.trainer.load_ui_state()
146
 
147
+ # Ensure proper type conversion for numeric values
148
+ ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
149
+ ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
150
  ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
151
  ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
152
  ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
 
868
  )
869
 
870
  def update_training_params(self, preset_name: str) -> Tuple:
871
+ """Update UI components based on selected preset while preserving custom settings"""
872
  preset = TRAINING_PRESETS[preset_name]
873
 
874
+ # Load current UI state to check if user has customized values
875
+ current_state = self.load_ui_values()
876
+
877
  # Find the display name that maps to our model type
878
  model_display_name = next(
879
  key for key, value in MODEL_TYPES.items()
 
893
  info_text = f"{description}{bucket_info}"
894
 
895
  # Return values in the same order as the output components
896
+ # Use preset defaults but preserve user-modified values if they exist
897
+ lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
898
+ lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
899
+ num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
900
+ batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
901
+ learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
902
+ save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
903
+
904
  return (
905
  model_display_name,
906
+ lora_rank_val,
907
+ lora_alpha_val,
908
+ num_epochs_val,
909
+ batch_size_val,
910
+ learning_rate_val,
911
+ save_iterations_val,
912
  info_text
913
  )
914
 
vms/training_service.py CHANGED
@@ -114,19 +114,30 @@ class TrainingService:
114
  "model_type": list(MODEL_TYPES.keys())[0],
115
  "lora_rank": "128",
116
  "lora_alpha": "128",
117
- "num_epochs": 70,
118
  "batch_size": 1,
119
  "learning_rate": 3e-5,
120
- "save_iterations": 500,
121
  "training_preset": list(TRAINING_PRESETS.keys())[0]
122
  }
123
 
124
  if not ui_state_file.exists():
125
  return default_state
126
-
127
  try:
128
  with open(ui_state_file, 'r') as f:
129
  saved_state = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
130
  # Make sure we have all keys (in case structure changed)
131
  merged_state = default_state.copy()
132
  merged_state.update(saved_state)
 
114
  "model_type": list(MODEL_TYPES.keys())[0],
115
  "lora_rank": "128",
116
  "lora_alpha": "128",
117
+ "num_epochs": 50,
118
  "batch_size": 1,
119
  "learning_rate": 3e-5,
120
+ "save_iterations": 200,
121
  "training_preset": list(TRAINING_PRESETS.keys())[0]
122
  }
123
 
124
  if not ui_state_file.exists():
125
  return default_state
126
+
127
  try:
128
  with open(ui_state_file, 'r') as f:
129
  saved_state = json.load(f)
130
+
131
+ # Convert numeric values to appropriate types
132
+ if "num_epochs" in saved_state:
133
+ saved_state["num_epochs"] = int(saved_state["num_epochs"])
134
+ if "batch_size" in saved_state:
135
+ saved_state["batch_size"] = int(saved_state["batch_size"])
136
+ if "learning_rate" in saved_state:
137
+ saved_state["learning_rate"] = float(saved_state["learning_rate"])
138
+ if "save_iterations" in saved_state:
139
+ saved_state["save_iterations"] = int(saved_state["save_iterations"])
140
+
141
  # Make sure we have all keys (in case structure changed)
142
  merged_state = default_state.copy()
143
  merged_state.update(saved_state)