Commit 32b4f0f · Parent: 4905a7d
making some fixes
Files changed:
- app.py (+41 / −77)
- config.py (+4 / −1)
- finetrainers/dataset.py (+3 / −3)
- training_log_parser.py (+6 / −2)
- training_service.py (+2 / −2)
app.py CHANGED

@@ -36,7 +36,7 @@ from splitting_service import SplittingService
 from import_service import ImportService
 from config import (
     STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
-    TRAINING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
+    TRAINING_PATH, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
     DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS
 )
 from utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time

@@ -134,6 +134,9 @@ class VideoTrainerUI:
         self.splitter.processing = False
         status_messages["splitting"] = "Scene detection stopped"
 
+        if LOG_FILE_PATH.exists():
+            LOG_FILE_PATH.unlink()
+
         # Clear all data directories
         for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
                      MODEL_PATH, OUTPUT_PATH]:
@@ -258,15 +261,11 @@ class VideoTrainerUI:
         # Only return name and status columns for display
         return [[file[0], file[1]] for file in files]
 
-    def update_training_buttons(self,
+    def update_training_buttons(self, status: str) -> Dict:
         """Update training control buttons based on state"""
-
-
-
-        is_training = False
-        is_paused = training_state["status"] == "paused"
-        is_completed = training_state["status"] in ["completed", "error", "stopped"]
-        #print(f"update_training_buttons: is_training = {is_training}, is_paused = {is_paused}, is_completed = {is_completed}")
+        is_training = status in ["training", "initializing"]
+        is_paused = status == "paused"
+        is_completed = status in ["completed", "error", "stopped"]
         return {
             "start_btn": gr.Button(
                 interactive=not is_training and not is_paused,

@@ -283,32 +282,20 @@ class VideoTrainerUI:
             )
         }
 
-    def handle_training_complete(self):
-        """Handle training completion"""
-        # Reset button states
-        return self.update_training_buttons({
-            "status": "completed",
-            "progress": "100%",
-            "current_step": 0,
-            "total_steps": 0
-        })
-
     def handle_pause_resume(self):
+        status, _, _ = self.get_latest_status_message_and_logs()
 
-        status
-
-        if status["status"] == "paused":
-            result = self.trainer.resume_training()
-            new_state = {"status": "training"}
+        if status == "paused":
+            self.trainer.resume_training()
         else:
-
-
-            return (
-
-
-            )
+            self.trainer.pause_training()
+
+        return self.get_latest_status_message_logs_and_button_labels()
+
+    def handle_stop(self):
+        self.trainer.stop_training()
+        return self.get_latest_status_message_logs_and_button_labels()
 
-
     def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str]]:
         """Handle selection of both video clips and images"""
         try:
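The reworked update_training_buttons now takes the parsed status string directly instead of a state dict. Below is a minimal, self-contained sketch of the status-to-interactivity mapping that signature implies; the properties of the stop and pause/resume buttons are assumptions, since only the start button's line is visible in the hunk.

import gradio as gr

def training_buttons(status: str) -> dict:
    # Same status buckets as the new update_training_buttons().
    is_training = status in ["training", "initializing"]
    is_paused = status == "paused"
    is_completed = status in ["completed", "error", "stopped"]  # presumably used beyond this hunk

    return {
        # Shown in the hunk: start is clickable only when nothing is running or paused.
        "start_btn": gr.Button("Start training", interactive=not is_training and not is_paused),
        # Assumed: stop and pause/resume are clickable only while a run is active or paused.
        "stop_btn": gr.Button("Stop", interactive=is_training or is_paused),
        "pause_resume_btn": gr.Button("Resume" if is_paused else "Pause",
                                      interactive=is_training or is_paused),
    }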
@@ -623,15 +610,10 @@ class VideoTrainerUI:
             return f"Error during scene detection: {str(e)}"
 
 
-    def
-
-        status = self.trainer.get_status()
+    def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
+        state = self.trainer.get_status()
         logs = self.trainer.get_logs()
 
-        status_update = status["message"]
-
-        # print(f"refresh_training_status_and_logs: ", status)
-
         # Parse new log lines
         if logs:
             last_state = None

@@ -639,42 +621,28 @@ class VideoTrainerUI:
                 state_update = self.log_parser.parse_line(line)
                 if state_update:
                     last_state = state_update
-                    print("last_state = ", last_state)
 
             if last_state:
                 ui_updates = self.update_training_ui(last_state)
-
-
-        return (status_update, logs)
-
-    def refresh_training_status(self):
-        """Refresh training status and update UI"""
-        status, logs = self.refresh_training_status_and_logs()
+                state["message"] = ui_updates.get("status_box", state["message"])
 
         # Parse status for training state
-
-
-
-
-
-
-
-
-        if is_completed:
-            button_updates = self.handle_training_complete()
-            return (
-                status,
-                logs,
-                *button_updates.values()
-            )
-
-        # Update based on current training state
-        button_updates = self.update_training_buttons(current_state)
+        if "completed" in state["message"].lower():
+            state["status"] = "completed"
+
+        return (state["status"], state["message"], logs)
+
+    def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
+        status, message, logs = self.get_latest_status_message_and_logs()
         return (
-
+            message,
             logs,
-            *
+            *self.update_training_buttons(status).values()
         )
+
+    def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
+        status, message, logs = self.get_latest_status_message_and_logs()
+        return self.update_training_buttons(status).values()
 
     def refresh_dataset(self):
         """Refresh all dynamic lists and training state"""
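The three new helpers share one contract: a (status, message, logs) triple, from which the five Gradio outputs wired up below (status_box, log_box and the three buttons) are derived. A tiny self-contained sketch of the completed-detection step, with made-up values standing in for what trainer.get_status() returns:

from typing import Tuple

def latest_status(state: dict, logs: str) -> Tuple[str, str, str]:
    # Mirrors the end of get_latest_status_message_and_logs(): a "completed"
    # message forces the status field to "completed".
    if "completed" in state["message"].lower():
        state["status"] = "completed"
    return state["status"], state["message"], logs

print(latest_status({"status": "training", "message": "Training completed"}, ""))
# ('completed', 'Training completed', '')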
@@ -1141,22 +1109,18 @@ class VideoTrainerUI:
                 ],
                 outputs=[status_box, log_box]
             ).success(
-                fn=
-                outputs=[start_btn, stop_btn, pause_resume_btn]
+                fn=self.get_latest_status_message_logs_and_button_labels,
+                outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
             )
 
-
             pause_resume_btn.click(
                 fn=self.handle_pause_resume,
                 outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
             )
 
             stop_btn.click(
-                fn=self.
-                outputs=[status_box, log_box]
-            ).success(
-                fn=self.handle_training_complete,
-                outputs=[start_btn, stop_btn, pause_resume_btn]
+                fn=self.handle_stop,
+                outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
             )
 
             def handle_global_stop():
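All three buttons now feed the same five outputs, and the start button keeps Gradio's event chaining: .success() runs a second callback only if the first one raised no error. A toy, self-contained sketch of that pattern (the callbacks here are stand-ins, not the app's real handlers):

import gradio as gr

def start_training():
    # Stand-in for the real training kickoff; returns (status, logs).
    return "Training started", "step 0\n"

def refresh_buttons():
    # Stand-in for the button part of get_latest_status_message_logs_and_button_labels().
    return (
        gr.Button(interactive=False),   # start
        gr.Button(interactive=True),    # stop
        gr.Button(interactive=True),    # pause/resume
    )

with gr.Blocks() as demo:
    status_box = gr.Textbox(label="status")
    log_box = gr.Textbox(label="logs")
    start_btn = gr.Button("Start")
    stop_btn = gr.Button("Stop")
    pause_resume_btn = gr.Button("Pause")

    # Update status/logs first; only on success, refresh the button states.
    start_btn.click(fn=start_training, outputs=[status_box, log_box]).success(
        fn=refresh_buttons, outputs=[start_btn, stop_btn, pause_resume_btn]
    )

# demo.launch()  # uncomment to run the sketch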
@@ -1218,12 +1182,12 @@ class VideoTrainerUI:
             timer = gr.Timer(value=1)
             timer.tick(
                 fn=lambda: (
-                    self.
+                    self.get_latest_status_message_logs_and_button_labels()
                 ),
                 outputs=[
                     status_box,
                     log_box,
-
+                    start_btn,
                     stop_btn,
                     pause_resume_btn
                 ]

@@ -1239,7 +1203,7 @@ class VideoTrainerUI:
                 ]
             )
 
-            timer = gr.Timer(value=
+            timer = gr.Timer(value=6)
             timer.tick(
                 fn=lambda: self.update_titles(),
                 outputs=[
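Two polling loops drive the live view: a 1-second timer for status, logs and button states, and a slower 6-second timer for title refreshes. A minimal stand-alone illustration of the gr.Timer(...).tick(...) pattern (component names mirror the app, callbacks are placeholders):

import time
import gradio as gr

with gr.Blocks() as demo:
    status_box = gr.Textbox(label="status")
    titles = gr.Markdown("idle")

    # Fast loop: refresh frequently-changing state every second.
    gr.Timer(value=1).tick(
        fn=lambda: f"last poll at {time.strftime('%H:%M:%S')}",
        outputs=[status_box],
    )
    # Slow loop: cheaper, less urgent refresh every 6 seconds.
    gr.Timer(value=6).tick(
        fn=lambda: f"titles refreshed at {time.strftime('%H:%M:%S')}",
        outputs=[titles],
    )

# demo.launch()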
config.py CHANGED

@@ -16,7 +16,8 @@ STAGING_PATH = STORAGE_PATH / "staging"  # This is where files
 TRAINING_PATH = STORAGE_PATH / "training"        # Folder containing the final training dataset
 TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos"  # Captioned clips ready for training
 MODEL_PATH = STORAGE_PATH / "model"              # Model checkpoints and files
-OUTPUT_PATH = STORAGE_PATH / "output"
+OUTPUT_PATH = STORAGE_PATH / "output"            # Training outputs and logs
+LOG_FILE_PATH = OUTPUT_PATH / "last_session.log"
 
 # On the production server we can afford to preload the big model
 PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))

@@ -66,6 +67,8 @@ TRAINING_HEIGHT = 512  # 32 * 16
 # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
 
 TRAINING_BUCKETS = [
+    (1, TRAINING_HEIGHT, TRAINING_WIDTH),          # 1
+    (8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH),      # 8 + 1
     (8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH),  # 16 + 1
     (8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH),  # 32 + 1
     (8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH),  # 48 + 1
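Every bucket's first entry is a frame count of the form 8·k + 1; the two new buckets give single images (1 frame) and very short clips (9 frames) a place to land instead of triggering the finetrainers crash the comment warns about. An equivalent construction, covering only the entries visible in this hunk (TRAINING_WIDTH is not part of the diff, so its value below is an assumption for illustration):

TRAINING_HEIGHT = 512  # from config.py (32 * 16)
TRAINING_WIDTH = 768   # assumed for illustration; the real value is defined elsewhere in config.py

# Frame counts follow 8*k + 1: 1, 9, 17, 33, 49, ...
TRAINING_BUCKETS = [(8 * k + 1, TRAINING_HEIGHT, TRAINING_WIDTH) for k in (0, 1, 2, 4, 6)]
print([bucket[0] for bucket in TRAINING_BUCKETS])  # [1, 9, 17, 33, 49]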
finetrainers/dataset.py CHANGED

@@ -266,9 +266,9 @@ class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
     def _preprocess_video(self, path: Path) -> torch.Tensor:
         video_reader = decord.VideoReader(uri=path.as_posix())
         video_num_frames = len(video_reader)
-        print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
-        print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
-        print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
+        #print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
+        #print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
+        #print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
 
         video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
 
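The filter on the last context line is what the extra buckets in config.py feed into: a clip can only use buckets whose frame count it can actually fill. A small stand-alone illustration (the 512x768 resolutions are the assumed values from the config sketch above):

# Buckets as (num_frames, height, width); frame counts from config.py's TRAINING_BUCKETS.
resolution_buckets = [(1, 512, 768), (9, 512, 768), (17, 512, 768), (33, 512, 768), (49, 512, 768)]

def matching_buckets(video_num_frames: int):
    # Same rule as in _preprocess_video(): keep buckets the clip has enough frames for.
    return [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]

print(matching_buckets(20))  # [(1, 512, 768), (9, 512, 768), (17, 512, 768)]
print(matching_buckets(1))   # [(1, 512, 768)]  <- single images now match a bucket too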
training_log_parser.py CHANGED

@@ -66,14 +66,18 @@ class TrainingLogParser:
         """Parse a single log line and update state"""
         try:
             # For debugging
-            logger.info(f"Parsing line: {line[:100]}...")
+            #logger.info(f"Parsing line: {line[:100]}...")
 
             # Training step progress line example:
             # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
+
+            if ("Started training" in line) or ("Starting training" in line):
+                self.state.status = "training"
+
             if "Training steps:" in line:
                 # Set status to training if we see this
                 self.state.status = "training"
+                #print("setting status to 'training'")
                 if not self.state.start_time:
                     self.state.start_time = datetime.now()
 
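The example line in the comment is the one the parser has to understand. The internals of parse_line are not part of this diff, so the regex below is purely illustrative of how step, loss and learning rate could be pulled out of such a line:

import re

line = "Training steps:   1%|▏         | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]"
match = re.search(
    r"Training steps:\s+(\d+)%.*?(\d+)/(\d+).*?step_loss=([\d.eE+-]+).*?lr=([\d.eE+-]+)",
    line,
)
if match:
    percent, current_step, total_steps, step_loss, lr = match.groups()
    print(current_step, total_steps, step_loss, lr)  # 1 70 0.555 3e-7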
training_service.py CHANGED

@@ -19,7 +19,7 @@ import select
 from typing import Any, Optional, Dict, List, Union, Tuple
 
 from huggingface_hub import upload_folder, create_repo
-from config import TrainingConfig, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
+from config import TrainingConfig, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
 from utils import make_archive, parse_training_log, is_image_file, is_video_file
 from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
 

@@ -29,7 +29,7 @@ logging.basicConfig(
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
     handlers=[
         logging.StreamHandler(sys.stdout),
-        logging.FileHandler(str(
+        logging.FileHandler(str(LOG_FILE_PATH))
     ]
 )
 logger = logging.getLogger(__name__)
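With this change, everything the training service logs goes both to stdout and to config.LOG_FILE_PATH, the same file app.py now deletes when the data directories are cleared. A self-contained sketch of the same handler setup (the path and log level here are placeholders; the diff does not show the configured level or which code creates the output directory):

import logging
import sys
from pathlib import Path

LOG_FILE_PATH = Path("output/last_session.log")           # stands in for config.LOG_FILE_PATH
LOG_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)   # FileHandler needs the directory to exist

logging.basicConfig(
    level=logging.INFO,  # assumed; not visible in the diff
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),        # console output
        logging.FileHandler(str(LOG_FILE_PATH)),  # per-session log file
    ],
)
logging.getLogger(__name__).info("training service started")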