Spaces:
Running
Running
Commit
·
743eda6
1
Parent(s):
446e79f
improve UI persistence
Browse files- app.py +21 -8
- vms/training_service.py +14 -3
app.py
CHANGED
@@ -144,7 +144,9 @@ class VideoTrainerUI:
|
|
144 |
"""Load UI state values for initializing form fields"""
|
145 |
ui_state = self.trainer.load_ui_state()
|
146 |
|
147 |
-
#
|
|
|
|
|
148 |
ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
|
149 |
ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
|
150 |
ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
|
@@ -866,9 +868,12 @@ class VideoTrainerUI:
|
|
866 |
)
|
867 |
|
868 |
def update_training_params(self, preset_name: str) -> Tuple:
|
869 |
-
"""Update UI components based on selected preset"""
|
870 |
preset = TRAINING_PRESETS[preset_name]
|
871 |
|
|
|
|
|
|
|
872 |
# Find the display name that maps to our model type
|
873 |
model_display_name = next(
|
874 |
key for key, value in MODEL_TYPES.items()
|
@@ -888,14 +893,22 @@ class VideoTrainerUI:
|
|
888 |
info_text = f"{description}{bucket_info}"
|
889 |
|
890 |
# Return values in the same order as the output components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
891 |
return (
|
892 |
model_display_name,
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
|
899 |
info_text
|
900 |
)
|
901 |
|
|
|
144 |
"""Load UI state values for initializing form fields"""
|
145 |
ui_state = self.trainer.load_ui_state()
|
146 |
|
147 |
+
# Ensure proper type conversion for numeric values
|
148 |
+
ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
|
149 |
+
ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
|
150 |
ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
|
151 |
ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
|
152 |
ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
|
|
|
868 |
)
|
869 |
|
870 |
def update_training_params(self, preset_name: str) -> Tuple:
|
871 |
+
"""Update UI components based on selected preset while preserving custom settings"""
|
872 |
preset = TRAINING_PRESETS[preset_name]
|
873 |
|
874 |
+
# Load current UI state to check if user has customized values
|
875 |
+
current_state = self.load_ui_values()
|
876 |
+
|
877 |
# Find the display name that maps to our model type
|
878 |
model_display_name = next(
|
879 |
key for key, value in MODEL_TYPES.items()
|
|
|
893 |
info_text = f"{description}{bucket_info}"
|
894 |
|
895 |
# Return values in the same order as the output components
|
896 |
+
# Use preset defaults but preserve user-modified values if they exist
|
897 |
+
lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
|
898 |
+
lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
|
899 |
+
num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
|
900 |
+
batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
|
901 |
+
learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
|
902 |
+
save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
|
903 |
+
|
904 |
return (
|
905 |
model_display_name,
|
906 |
+
lora_rank_val,
|
907 |
+
lora_alpha_val,
|
908 |
+
num_epochs_val,
|
909 |
+
batch_size_val,
|
910 |
+
learning_rate_val,
|
911 |
+
save_iterations_val,
|
912 |
info_text
|
913 |
)
|
914 |
|
vms/training_service.py
CHANGED
@@ -114,19 +114,30 @@ class TrainingService:
|
|
114 |
"model_type": list(MODEL_TYPES.keys())[0],
|
115 |
"lora_rank": "128",
|
116 |
"lora_alpha": "128",
|
117 |
-
"num_epochs":
|
118 |
"batch_size": 1,
|
119 |
"learning_rate": 3e-5,
|
120 |
-
"save_iterations":
|
121 |
"training_preset": list(TRAINING_PRESETS.keys())[0]
|
122 |
}
|
123 |
|
124 |
if not ui_state_file.exists():
|
125 |
return default_state
|
126 |
-
|
127 |
try:
|
128 |
with open(ui_state_file, 'r') as f:
|
129 |
saved_state = json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
# Make sure we have all keys (in case structure changed)
|
131 |
merged_state = default_state.copy()
|
132 |
merged_state.update(saved_state)
|
|
|
114 |
"model_type": list(MODEL_TYPES.keys())[0],
|
115 |
"lora_rank": "128",
|
116 |
"lora_alpha": "128",
|
117 |
+
"num_epochs": 50,
|
118 |
"batch_size": 1,
|
119 |
"learning_rate": 3e-5,
|
120 |
+
"save_iterations": 200,
|
121 |
"training_preset": list(TRAINING_PRESETS.keys())[0]
|
122 |
}
|
123 |
|
124 |
if not ui_state_file.exists():
|
125 |
return default_state
|
126 |
+
|
127 |
try:
|
128 |
with open(ui_state_file, 'r') as f:
|
129 |
saved_state = json.load(f)
|
130 |
+
|
131 |
+
# Convert numeric values to appropriate types
|
132 |
+
if "num_epochs" in saved_state:
|
133 |
+
saved_state["num_epochs"] = int(saved_state["num_epochs"])
|
134 |
+
if "batch_size" in saved_state:
|
135 |
+
saved_state["batch_size"] = int(saved_state["batch_size"])
|
136 |
+
if "learning_rate" in saved_state:
|
137 |
+
saved_state["learning_rate"] = float(saved_state["learning_rate"])
|
138 |
+
if "save_iterations" in saved_state:
|
139 |
+
saved_state["save_iterations"] = int(saved_state["save_iterations"])
|
140 |
+
|
141 |
# Make sure we have all keys (in case structure changed)
|
142 |
merged_state = default_state.copy()
|
143 |
merged_state.update(saved_state)
|