Commit 947f205 by jbilcke-hf · Parent(s): 32b4f0f

ready for the demo
Files changed (5)
  1. .gitignore +1 -0
  2. app.py +21 -13
  3. finetrainers_utils.py +7 -3
  4. training_log_parser.py +1 -1
  5. training_service.py +72 -13
.gitignore CHANGED
@@ -6,3 +6,4 @@ __pycache__
 *.mp4
 *.zip
 training_service.log
+wandb/
app.py CHANGED
@@ -125,8 +125,6 @@ class VideoTrainerUI:
             # Stop captioning if running
             if self.captioner:
                 self.captioner.stop_captioning()
-                #self.captioner.close()
-                #self.captioner = None
                 status_messages["captioning"] = "Captioning stopped"
 
             # Stop scene detection if running
@@ -134,6 +132,12 @@
                 self.splitter.processing = False
                 status_messages["splitting"] = "Scene detection stopped"
 
+            # Properly close logging before clearing log file
+            if self.trainer.file_handler:
+                self.trainer.file_handler.close()
+                logger.removeHandler(self.trainer.file_handler)
+                self.trainer.file_handler = None
+
             if LOG_FILE_PATH.exists():
                 LOG_FILE_PATH.unlink()
 
@@ -153,6 +157,9 @@
             self._should_stop_captioning = True
             self.splitter.processing = False
 
+            # Recreate logging setup
+            self.trainer.setup_logging()
+
             return {
                 "status": "All processes stopped and data cleared",
                 "details": status_messages
@@ -163,7 +170,7 @@
                 "status": f"Error during cleanup: {str(e)}",
                 "details": status_messages
             }
-
+
     def update_titles(self) -> Tuple[Any]:
         """Update all dynamic titles with current counts
 
@@ -664,20 +671,20 @@
             with gr.TabItem("1️⃣ Import", id="import_tab"):
 
                 with gr.Row():
-                    gr.Markdown("## Optional: automated data cleaning")
+                    gr.Markdown("## Automatic splitting and captioning")
 
                 with gr.Row():
                     enable_automatic_video_split = gr.Checkbox(
                         label="Automatically split videos into smaller clips",
                         info="Note: a clip is a single camera shot, usually a few seconds",
                         value=True,
-                        visible=False
+                        visible=True
                     )
                     enable_automatic_content_captioning = gr.Checkbox(
                         label="Automatically caption photos and videos",
                         info="Note: this uses LlaVA and takes some extra time to load and process",
                         value=False,
-                        visible=False,
+                        visible=True,
                     )
 
                 with gr.Row():
@@ -889,13 +896,14 @@
                         interactive=False,
                         lines=4
                     )
-                    log_box = gr.TextArea(
-                        label="Training Logs",
-                        interactive=False,
-                        lines=10,
-                        max_lines=40,
-                        autoscroll=True
-                    )
+                    with gr.Accordion("See training logs"):
+                        log_box = gr.TextArea(
+                            label="Finetrainers output (see HF Space logs for more details)",
+                            interactive=False,
+                            lines=40,
+                            max_lines=200,
+                            autoscroll=True
+                        )
 
             with gr.TabItem("5️⃣ Manage"):
 
 
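For context, a minimal, self-contained sketch of the collapsible log panel pattern that app.py now uses; the log path, the read_logs helper and the refresh button are illustrative assumptions, not code from this commit:

import gradio as gr
from pathlib import Path

LOG_FILE = Path("training_service.log")  # hypothetical location for this demo

def read_logs() -> str:
    # Return the current log contents, or a placeholder when nothing has been written yet
    return LOG_FILE.read_text() if LOG_FILE.exists() else "(no logs yet)"

with gr.Blocks() as demo:
    with gr.Accordion("See training logs", open=False):
        log_box = gr.TextArea(
            label="Finetrainers output",
            interactive=False,
            lines=40,
            max_lines=200,
            autoscroll=True,
        )
    refresh_btn = gr.Button("Refresh logs")
    refresh_btn.click(fn=read_logs, outputs=log_box)

if __name__ == "__main__":
    demo.launch()
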
finetrainers_utils.py CHANGED
@@ -115,9 +115,13 @@ def copy_files_to_training_dir(prompt_prefix: str) -> int:
 
         # make sure we only copy over VALID pairs
         if caption:
-            target_caption_path.write_text(caption)
-            shutil.copy2(file_path, target_file_path)
-            nb_copied_pairs += 1
+            try:
+                target_caption_path.write_text(caption)
+                shutil.copy2(file_path, target_file_path)
+                nb_copied_pairs += 1
+            except Exception as e:
+                print(f"failed to copy one of the pairs: {e}")
+                pass
 
     prepare_finetrainers_dataset()
 
 
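The change above wraps each caption/media copy in a try/except so that one unreadable file no longer aborts the whole import. A hedged, standalone sketch of that "copy only valid pairs, skip failures" idea (directory layout and function name are assumptions, not the repo's code):

import shutil
from pathlib import Path

def copy_valid_pairs(src_dir: Path, dst_dir: Path) -> int:
    """Copy each video together with its .txt caption; skip pairs that fail."""
    dst_dir.mkdir(parents=True, exist_ok=True)
    nb_copied_pairs = 0
    for video_path in src_dir.glob("*.mp4"):
        caption_path = video_path.with_suffix(".txt")
        if not caption_path.exists():
            continue  # no caption -> not a valid pair
        try:
            (dst_dir / caption_path.name).write_text(caption_path.read_text())
            shutil.copy2(video_path, dst_dir / video_path.name)
            nb_copied_pairs += 1
        except Exception as e:
            # a single bad file should not abort the whole copy
            print(f"failed to copy one of the pairs: {e}")
    return nb_copied_pairs
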
training_log_parser.py CHANGED
@@ -71,7 +71,7 @@ class TrainingLogParser:
         # Training step progress line example:
         # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
 
-        if ("Started training" in line) or (("Starting training" in line):
+        if ("Started training" in line) or ("Starting training" in line):
             self.state.status = "training"
 
         if "Training steps:" in line:
training_service.py CHANGED
@@ -23,15 +23,6 @@ from config import TrainingConfig, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_
 from utils import make_archive, parse_training_log, is_image_file, is_video_file
 from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
 
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    handlers=[
-        logging.StreamHandler(sys.stdout),
-        logging.FileHandler(str(LOG_FILE_PATH))
-    ]
-)
 logger = logging.getLogger(__name__)
 
 class TrainingService:
@@ -41,8 +32,69 @@ class TrainingService:
         self.status_file = OUTPUT_PATH / "status.json"
         self.pid_file = OUTPUT_PATH / "training.pid"
         self.log_file = OUTPUT_PATH / "training.log"
+
+        self.file_handler = None
+        self.setup_logging()
+
         logger.info("Training service initialized")
 
+    def setup_logging(self):
+        """Set up logging with proper handler management"""
+        global logger
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.INFO)
+
+        # Remove any existing handlers to avoid duplicates
+        logger.handlers.clear()
+
+        # Add stdout handler
+        stdout_handler = logging.StreamHandler(sys.stdout)
+        stdout_handler.setFormatter(logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        ))
+        logger.addHandler(stdout_handler)
+
+        # Add file handler if log file is accessible
+        try:
+            # Close existing file handler if it exists
+            if self.file_handler:
+                self.file_handler.close()
+                logger.removeHandler(self.file_handler)
+
+            self.file_handler = logging.FileHandler(str(LOG_FILE_PATH))
+            self.file_handler.setFormatter(logging.Formatter(
+                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+            ))
+            logger.addHandler(self.file_handler)
+        except Exception as e:
+            logger.warning(f"Could not set up log file: {e}")
+
+    def clear_logs(self) -> None:
+        """Clear log file with proper handler cleanup"""
+        try:
+            # Remove and close the file handler
+            if self.file_handler:
+                logger.removeHandler(self.file_handler)
+                self.file_handler.close()
+                self.file_handler = None
+
+            # Delete the file if it exists
+            if LOG_FILE_PATH.exists():
+                LOG_FILE_PATH.unlink()
+
+            # Recreate logging setup
+            self.setup_logging()
+            self.append_log("Log file cleared and recreated")
+
+        except Exception as e:
+            logger.error(f"Error clearing logs: {e}")
+            raise
+
+    def __del__(self):
+        """Cleanup when the service is destroyed"""
+        if self.file_handler:
+            self.file_handler.close()
+
     def save_session(self, params: Dict) -> None:
         """Save training session parameters"""
         session_data = {
@@ -73,7 +125,7 @@
         try:
             with open(self.status_file, 'r') as f:
                 status = json.load(f)
-                print("status found in the json:", status)
+                #print("status found in the json:", status)
 
             # Check if process is actually running
             if self.pid_file.exists():
@@ -81,7 +133,7 @@
                     pid = int(f.read().strip())
                     if not psutil.pid_exists(pid):
                         # Process died unexpectedly
-                        if status['status'] == 'running':
+                        if status['status'] == 'training':
                             status['status'] = 'error'
                             status['message'] = 'Training process terminated unexpectedly'
                             self.append_log("Training process terminated unexpectedly")
@@ -302,7 +354,7 @@
         # Update initial training status
         total_steps = num_epochs * (max(1, video_count) // batch_size)
         self.save_status(
-            state='running',
+            state='training',
             epoch=0,
             step=0,
             total_steps=total_steps,
@@ -389,7 +441,7 @@
 
             if psutil.pid_exists(pid):
                 os.kill(pid, signal.SIGUSR2)  # Signal to resume
-                self.save_status(state='running', message='Training resumed')
+                self.save_status(state='training', message='Training resumed')
                 self.append_log("Training resumed")
 
                 return "Training resumed", self.get_logs()
@@ -437,6 +489,13 @@
             'timestamp': datetime.now().isoformat(),
             **kwargs
        }
+        if state == "Training started" or state == "initializing":
+            gr.Info("Initializing model and dataset..")
+        elif state == "training":
+            gr.Info("Training started!")
+        elif state == "completed":
+            gr.Info("Training completed!")
+
         with open(self.status_file, 'w') as f:
             json.dump(status, f, indent=2)
 
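Besides the logging rework, the diff renames the in-progress state from 'running' to 'training'. A hedged sketch of the dead-process check that relies on this state (file names and layout are assumptions made for illustration):

import json
from pathlib import Path

import psutil

STATUS_FILE = Path("output/status.json")  # assumed location
PID_FILE = Path("output/training.pid")    # assumed location

def check_status() -> dict:
    """Flag an error when the recorded training process is no longer alive."""
    status = json.loads(STATUS_FILE.read_text())
    if PID_FILE.exists():
        pid = int(PID_FILE.read_text().strip())
        # Only report an error if we believed training was still in progress
        if not psutil.pid_exists(pid) and status["status"] == "training":
            status["status"] = "error"
            status["message"] = "Training process terminated unexpectedly"
            STATUS_FILE.write_text(json.dumps(status, indent=2))
    return status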