jbilcke-hf HF Staff commited on
Commit
32b4f0f
·
1 Parent(s): 4905a7d

making some fixes

Browse files
Files changed (5) hide show
  1. app.py +41 -77
  2. config.py +4 -1
  3. finetrainers/dataset.py +3 -3
  4. training_log_parser.py +6 -2
  5. training_service.py +2 -2
app.py CHANGED
@@ -36,7 +36,7 @@ from splitting_service import SplittingService
36
  from import_service import ImportService
37
  from config import (
38
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
39
- TRAINING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
40
  DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS
41
  )
42
  from utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
@@ -134,6 +134,9 @@ class VideoTrainerUI:
134
  self.splitter.processing = False
135
  status_messages["splitting"] = "Scene detection stopped"
136
 
 
 
 
137
  # Clear all data directories
138
  for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
139
  MODEL_PATH, OUTPUT_PATH]:
@@ -258,15 +261,11 @@ class VideoTrainerUI:
258
  # Only return name and status columns for display
259
  return [[file[0], file[1]] for file in files]
260
 
261
- def update_training_buttons(self, training_state: Dict[str, Any]) -> Dict:
262
  """Update training control buttons based on state"""
263
- #print("update_training_buttons: training_state = ", training_state)
264
- is_training = training_state["status"] in ["training", "initializing"]
265
- if training_state["message"] == "No training in progress":
266
- is_training = False
267
- is_paused = training_state["status"] == "paused"
268
- is_completed = training_state["status"] in ["completed", "error", "stopped"]
269
- #print(f"update_training_buttons: is_training = {is_training}, is_paused = {is_paused}, is_completed = {is_completed}")
270
  return {
271
  "start_btn": gr.Button(
272
  interactive=not is_training and not is_paused,
@@ -283,32 +282,20 @@ class VideoTrainerUI:
283
  )
284
  }
285
 
286
- def handle_training_complete(self):
287
- """Handle training completion"""
288
- # Reset button states
289
- return self.update_training_buttons({
290
- "status": "completed",
291
- "progress": "100%",
292
- "current_step": 0,
293
- "total_steps": 0
294
- })
295
-
296
  def handle_pause_resume(self):
 
297
 
298
- status = self.trainer.get_status()
299
- print("handle_pause_resume: status = ", status)
300
- if status["status"] == "paused":
301
- result = self.trainer.resume_training()
302
- new_state = {"status": "training"}
303
  else:
304
- result = self.trainer.pause_training()
305
- new_state = {"status": "paused"}
306
- return (
307
- *result,
308
- *self.update_training_buttons(new_state).values()
309
- )
 
310
 
311
-
312
  def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str]]:
313
  """Handle selection of both video clips and images"""
314
  try:
@@ -623,15 +610,10 @@ class VideoTrainerUI:
623
  return f"Error during scene detection: {str(e)}"
624
 
625
 
626
- def refresh_training_status_and_logs(self):
627
- """Refresh all dynamic lists and training state"""
628
- status = self.trainer.get_status()
629
  logs = self.trainer.get_logs()
630
 
631
- status_update = status["message"]
632
-
633
- # print(f"refresh_training_status_and_logs: ", status)
634
-
635
  # Parse new log lines
636
  if logs:
637
  last_state = None
@@ -639,42 +621,28 @@ class VideoTrainerUI:
639
  state_update = self.log_parser.parse_line(line)
640
  if state_update:
641
  last_state = state_update
642
- print("last_state = ", last_state)
643
 
644
  if last_state:
645
  ui_updates = self.update_training_ui(last_state)
646
- status_update = ui_updates.get("status_box", status["message"])
647
-
648
- return (status_update, logs)
649
-
650
- def refresh_training_status(self):
651
- """Refresh training status and update UI"""
652
- status, logs = self.refresh_training_status_and_logs()
653
 
654
  # Parse status for training state
655
- is_completed = "completed" in status.lower() or "100.0%" in status
656
- current_state = {
657
- "status": "completed" if is_completed else "training",
658
- "message": status
659
- }
660
-
661
- #print("refresh_training_status: current_state = ", current_state)
662
-
663
- if is_completed:
664
- button_updates = self.handle_training_complete()
665
- return (
666
- status,
667
- logs,
668
- *button_updates.values()
669
- )
670
-
671
- # Update based on current training state
672
- button_updates = self.update_training_buttons(current_state)
673
  return (
674
- status,
675
  logs,
676
- *button_updates.values()
677
  )
 
 
 
 
678
 
679
  def refresh_dataset(self):
680
  """Refresh all dynamic lists and training state"""
@@ -1141,22 +1109,18 @@ class VideoTrainerUI:
1141
  ],
1142
  outputs=[status_box, log_box]
1143
  ).success(
1144
- fn=lambda: self.update_training_buttons(),
1145
- outputs=[start_btn, stop_btn, pause_resume_btn]
1146
  )
1147
 
1148
-
1149
  pause_resume_btn.click(
1150
  fn=self.handle_pause_resume,
1151
  outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1152
  )
1153
 
1154
  stop_btn.click(
1155
- fn=self.trainer.stop_training,
1156
- outputs=[status_box, log_box]
1157
- ).success(
1158
- fn=self.handle_training_complete,
1159
- outputs=[start_btn, stop_btn, pause_resume_btn]
1160
  )
1161
 
1162
  def handle_global_stop():
@@ -1218,12 +1182,12 @@ class VideoTrainerUI:
1218
  timer = gr.Timer(value=1)
1219
  timer.tick(
1220
  fn=lambda: (
1221
- self.refresh_training_status()
1222
  ),
1223
  outputs=[
1224
  status_box,
1225
  log_box,
1226
- start_btn,
1227
  stop_btn,
1228
  pause_resume_btn
1229
  ]
@@ -1239,7 +1203,7 @@ class VideoTrainerUI:
1239
  ]
1240
  )
1241
 
1242
- timer = gr.Timer(value=5)
1243
  timer.tick(
1244
  fn=lambda: self.update_titles(),
1245
  outputs=[
 
36
  from import_service import ImportService
37
  from config import (
38
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
39
+ TRAINING_PATH, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
40
  DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS
41
  )
42
  from utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
 
134
  self.splitter.processing = False
135
  status_messages["splitting"] = "Scene detection stopped"
136
 
137
+ if LOG_FILE_PATH.exists():
138
+ LOG_FILE_PATH.unlink()
139
+
140
  # Clear all data directories
141
  for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
142
  MODEL_PATH, OUTPUT_PATH]:
 
261
  # Only return name and status columns for display
262
  return [[file[0], file[1]] for file in files]
263
 
264
+ def update_training_buttons(self, status: str) -> Dict:
265
  """Update training control buttons based on state"""
266
+ is_training = status in ["training", "initializing"]
267
+ is_paused = status == "paused"
268
+ is_completed = status in ["completed", "error", "stopped"]
 
 
 
 
269
  return {
270
  "start_btn": gr.Button(
271
  interactive=not is_training and not is_paused,
 
282
  )
283
  }
284
 
 
 
 
 
 
 
 
 
 
 
285
  def handle_pause_resume(self):
286
+ status, _, _ = self.get_latest_status_message_and_logs()
287
 
288
+ if status == "paused":
289
+ self.trainer.resume_training()
 
 
 
290
  else:
291
+ self.trainer.pause_training()
292
+
293
+ return self.get_latest_status_message_logs_and_button_labels()
294
+
295
+ def handle_stop(self):
296
+ self.trainer.stop_training()
297
+ return self.get_latest_status_message_logs_and_button_labels()
298
 
 
299
  def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str]]:
300
  """Handle selection of both video clips and images"""
301
  try:
 
610
  return f"Error during scene detection: {str(e)}"
611
 
612
 
613
+ def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
614
+ state = self.trainer.get_status()
 
615
  logs = self.trainer.get_logs()
616
 
 
 
 
 
617
  # Parse new log lines
618
  if logs:
619
  last_state = None
 
621
  state_update = self.log_parser.parse_line(line)
622
  if state_update:
623
  last_state = state_update
 
624
 
625
  if last_state:
626
  ui_updates = self.update_training_ui(last_state)
627
+ state["message"] = ui_updates.get("status_box", state["message"])
 
 
 
 
 
 
628
 
629
  # Parse status for training state
630
+ if "completed" in state["message"].lower():
631
+ state["status"] = "completed"
632
+
633
+ return (state["status"], state["message"], logs)
634
+
635
+ def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
636
+ status, message, logs = self.get_latest_status_message_and_logs()
 
 
 
 
 
 
 
 
 
 
 
637
  return (
638
+ message,
639
  logs,
640
+ *self.update_training_buttons(status).values()
641
  )
642
+
643
+ def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
644
+ status, message, logs = self.get_latest_status_message_and_logs()
645
+ return self.update_training_buttons(status).values()
646
 
647
  def refresh_dataset(self):
648
  """Refresh all dynamic lists and training state"""
 
1109
  ],
1110
  outputs=[status_box, log_box]
1111
  ).success(
1112
+ fn=self.get_latest_status_message_logs_and_button_labels,
1113
+ outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1114
  )
1115
 
 
1116
  pause_resume_btn.click(
1117
  fn=self.handle_pause_resume,
1118
  outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1119
  )
1120
 
1121
  stop_btn.click(
1122
+ fn=self.handle_stop,
1123
+ outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
 
 
 
1124
  )
1125
 
1126
  def handle_global_stop():
 
1182
  timer = gr.Timer(value=1)
1183
  timer.tick(
1184
  fn=lambda: (
1185
+ self.get_latest_status_message_logs_and_button_labels()
1186
  ),
1187
  outputs=[
1188
  status_box,
1189
  log_box,
1190
+ start_btn,
1191
  stop_btn,
1192
  pause_resume_btn
1193
  ]
 
1203
  ]
1204
  )
1205
 
1206
+ timer = gr.Timer(value=6)
1207
  timer.tick(
1208
  fn=lambda: self.update_titles(),
1209
  outputs=[
config.py CHANGED
@@ -16,7 +16,8 @@ STAGING_PATH = STORAGE_PATH / "staging" # This is where files
16
  TRAINING_PATH = STORAGE_PATH / "training" # Folder containing the final training dataset
17
  TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos" # Captioned clips ready for training
18
  MODEL_PATH = STORAGE_PATH / "model" # Model checkpoints and files
19
- OUTPUT_PATH = STORAGE_PATH / "output" # Training outputs and logs
 
20
 
21
  # On the production server we can afford to preload the big model
22
  PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))
@@ -66,6 +67,8 @@ TRAINING_HEIGHT = 512 # 32 * 16
66
  # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
67
 
68
  TRAINING_BUCKETS = [
 
 
69
  (8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 16 + 1
70
  (8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 32 + 1
71
  (8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 48 + 1
 
16
  TRAINING_PATH = STORAGE_PATH / "training" # Folder containing the final training dataset
17
  TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos" # Captioned clips ready for training
18
  MODEL_PATH = STORAGE_PATH / "model" # Model checkpoints and files
19
+ OUTPUT_PATH = STORAGE_PATH / "output" # Training outputs and logs
20
+ LOG_FILE_PATH = OUTPUT_PATH / "last_session.log"
21
 
22
  # On the production server we can afford to preload the big model
23
  PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))
 
67
  # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
68
 
69
  TRAINING_BUCKETS = [
70
+ (1, TRAINING_HEIGHT, TRAINING_WIDTH), # 1
71
+ (8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 8 + 1
72
  (8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 16 + 1
73
  (8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 32 + 1
74
  (8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 48 + 1
finetrainers/dataset.py CHANGED
@@ -266,9 +266,9 @@ class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
266
  def _preprocess_video(self, path: Path) -> torch.Tensor:
267
  video_reader = decord.VideoReader(uri=path.as_posix())
268
  video_num_frames = len(video_reader)
269
- print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
270
- print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
271
- print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
272
 
273
  video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
274
 
 
266
  def _preprocess_video(self, path: Path) -> torch.Tensor:
267
  video_reader = decord.VideoReader(uri=path.as_posix())
268
  video_num_frames = len(video_reader)
269
+ #print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
270
+ #print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
271
+ #print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
272
 
273
  video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
274
 
training_log_parser.py CHANGED
@@ -66,14 +66,18 @@ class TrainingLogParser:
66
  """Parse a single log line and update state"""
67
  try:
68
  # For debugging
69
- logger.info(f"Parsing line: {line[:100]}...")
70
 
71
  # Training step progress line example:
72
  # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
 
 
 
 
73
  if "Training steps:" in line:
74
  # Set status to training if we see this
75
  self.state.status = "training"
76
- print("setting status to 'training'")
77
  if not self.state.start_time:
78
  self.state.start_time = datetime.now()
79
 
 
66
  """Parse a single log line and update state"""
67
  try:
68
  # For debugging
69
+ #logger.info(f"Parsing line: {line[:100]}...")
70
 
71
  # Training step progress line example:
72
  # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
73
+
74
+ if ("Started training" in line) or ("Starting training" in line):
75
+ self.state.status = "training"
76
+
77
  if "Training steps:" in line:
78
  # Set status to training if we see this
79
  self.state.status = "training"
80
+ #print("setting status to 'training'")
81
  if not self.state.start_time:
82
  self.state.start_time = datetime.now()
83
 
training_service.py CHANGED
@@ -19,7 +19,7 @@ import select
19
  from typing import Any, Optional, Dict, List, Union, Tuple
20
 
21
  from huggingface_hub import upload_folder, create_repo
22
- from config import TrainingConfig, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
23
  from utils import make_archive, parse_training_log, is_image_file, is_video_file
24
  from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
25
 
@@ -29,7 +29,7 @@ logging.basicConfig(
29
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
30
  handlers=[
31
  logging.StreamHandler(sys.stdout),
32
- logging.FileHandler(str(OUTPUT_PATH / 'training_service.log'))
33
  ]
34
  )
35
  logger = logging.getLogger(__name__)
 
19
  from typing import Any, Optional, Dict, List, Union, Tuple
20
 
21
  from huggingface_hub import upload_folder, create_repo
22
+ from config import TrainingConfig, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
23
  from utils import make_archive, parse_training_log, is_image_file, is_video_file
24
  from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
25
 
 
29
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
30
  handlers=[
31
  logging.StreamHandler(sys.stdout),
32
+ logging.FileHandler(str(LOG_FILE_PATH))
33
  ]
34
  )
35
  logger = logging.getLogger(__name__)