jbilcke-hf HF staff commited on
Commit
38cfbff
Β·
1 Parent(s): 29d6f3c

working to improve log reporting

Browse files
vms/services/trainer.py CHANGED
@@ -834,7 +834,6 @@ class TrainingService:
834
  params = last_session.get('params', {})
835
 
836
  # Map internal model type back to display name for UI
837
- # This is the key fix for the "ltx_video" vs "LTX-Video (LoRA)" mismatch
838
  model_type_internal = params.get('model_type')
839
  model_type_display = model_type_internal
840
 
 
834
  params = last_session.get('params', {})
835
 
836
  # Map internal model type back to display name for UI
 
837
  model_type_internal = params.get('model_type')
838
  model_type_display = model_type_internal
839
 
vms/tabs/train_tab.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Train tab for Video Model Studio UI
3
  """
4
 
5
  import gradio as gr
@@ -126,7 +126,7 @@ class TrainTab(BaseTab):
126
  visible=False
127
  )
128
 
129
- # Add delete checkpoints button - THIS IS THE KEY FIX
130
  self.components["delete_checkpoints_btn"] = gr.Button(
131
  "Delete All Checkpoints",
132
  variant="stop",
@@ -140,6 +140,15 @@ class TrainTab(BaseTab):
140
  interactive=False,
141
  lines=4
142
  )
 
 
 
 
 
 
 
 
 
143
  with gr.Accordion("See training logs"):
144
  self.components["log_box"] = gr.TextArea(
145
  label="Finetrainers output (see HF Space logs for more details)",
@@ -288,7 +297,8 @@ class TrainTab(BaseTab):
288
  self.components["log_box"],
289
  self.components["start_btn"],
290
  self.components["stop_btn"],
291
- self.components["pause_resume_btn"]
 
292
  ]
293
  )
294
 
@@ -299,7 +309,8 @@ class TrainTab(BaseTab):
299
  self.components["log_box"],
300
  self.components["start_btn"],
301
  self.components["stop_btn"],
302
- self.components["pause_resume_btn"]
 
303
  ]
304
  )
305
 
@@ -310,7 +321,8 @@ class TrainTab(BaseTab):
310
  self.components["log_box"],
311
  self.components["start_btn"],
312
  self.components["stop_btn"],
313
- self.components["pause_resume_btn"]
 
314
  ]
315
  )
316
 
@@ -325,7 +337,8 @@ class TrainTab(BaseTab):
325
  self.components["log_box"],
326
  self.components["start_btn"],
327
  self.components["stop_btn"],
328
- self.components["delete_checkpoints_btn"]
 
329
  ]
330
  )
331
 
@@ -555,6 +568,12 @@ class TrainTab(BaseTab):
555
 
556
  updates["status_box"] = "\n".join(status_text)
557
 
 
 
 
 
 
 
558
  # Update button states
559
  updates["start_btn"] = gr.Button(
560
  "Start training",
@@ -638,6 +657,10 @@ class TrainTab(BaseTab):
638
  elif "stopped" in state["message"].lower():
639
  state["status"] = "stopped"
640
 
 
 
 
 
641
  return (state["status"], state["message"], logs)
642
 
643
  def get_latest_status_message_logs_and_button_labels(self) -> Tuple:
@@ -649,8 +672,13 @@ class TrainTab(BaseTab):
649
 
650
  button_updates = self.update_training_buttons(status, has_checkpoints).values()
651
 
652
- # Return in order expected by timer
653
- return (message, logs, *button_updates)
 
 
 
 
 
654
 
655
  def update_training_buttons(self, status: str, has_checkpoints: bool = None) -> Dict:
656
  """Update training control buttons based on state"""
 
1
  """
2
+ Train tab for Video Model Studio UI with improved task progress display
3
  """
4
 
5
  import gradio as gr
 
126
  visible=False
127
  )
128
 
129
+ # Add delete checkpoints button
130
  self.components["delete_checkpoints_btn"] = gr.Button(
131
  "Delete All Checkpoints",
132
  variant="stop",
 
140
  interactive=False,
141
  lines=4
142
  )
143
+
144
+ # Add new component for current task progress
145
+ self.components["current_task_box"] = gr.Textbox(
146
+ label="Current Task Progress",
147
+ interactive=False,
148
+ lines=3,
149
+ elem_id="current_task_display"
150
+ )
151
+
152
  with gr.Accordion("See training logs"):
153
  self.components["log_box"] = gr.TextArea(
154
  label="Finetrainers output (see HF Space logs for more details)",
 
297
  self.components["log_box"],
298
  self.components["start_btn"],
299
  self.components["stop_btn"],
300
+ self.components["pause_resume_btn"],
301
+ self.components["current_task_box"] # Include new component
302
  ]
303
  )
304
 
 
309
  self.components["log_box"],
310
  self.components["start_btn"],
311
  self.components["stop_btn"],
312
+ self.components["pause_resume_btn"],
313
+ self.components["current_task_box"] # Include new component
314
  ]
315
  )
316
 
 
321
  self.components["log_box"],
322
  self.components["start_btn"],
323
  self.components["stop_btn"],
324
+ self.components["pause_resume_btn"],
325
+ self.components["current_task_box"] # Include new component
326
  ]
327
  )
328
 
 
337
  self.components["log_box"],
338
  self.components["start_btn"],
339
  self.components["stop_btn"],
340
+ self.components["delete_checkpoints_btn"],
341
+ self.components["current_task_box"] # Include new component
342
  ]
343
  )
344
 
 
568
 
569
  updates["status_box"] = "\n".join(status_text)
570
 
571
+ # Add current task information to the dedicated box
572
+ if training_state.get("current_task"):
573
+ updates["current_task_box"] = training_state["current_task"]
574
+ else:
575
+ updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."
576
+
577
  # Update button states
578
  updates["start_btn"] = gr.Button(
579
  "Start training",
 
657
  elif "stopped" in state["message"].lower():
658
  state["status"] = "stopped"
659
 
660
+ # Add the current task info if available
661
+ if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
662
+ state["current_task"] = self.app.log_parser.get_current_task_display()
663
+
664
  return (state["status"], state["message"], logs)
665
 
666
  def get_latest_status_message_logs_and_button_labels(self) -> Tuple:
 
672
 
673
  button_updates = self.update_training_buttons(status, has_checkpoints).values()
674
 
675
+ # Get current task if available
676
+ current_task = ""
677
+ if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
678
+ current_task = self.app.log_parser.get_current_task_display()
679
+
680
+ # Return in order expected by timer (added current_task)
681
+ return (message, logs, *button_updates, current_task)
682
 
683
  def update_training_buttons(self, status: str, has_checkpoints: bool = None) -> Dict:
684
  """Update training control buttons based on state"""
vms/ui/video_trainer_ui.py CHANGED
@@ -89,13 +89,14 @@ class VideoTrainerUI:
89
  self.tabs["train_tab"].components["pause_resume_btn"],
90
  self.tabs["train_tab"].components["training_preset"],
91
  self.tabs["train_tab"].components["model_type"],
92
- self.tabs["train_tab"].components["training_type"], # Add the new training_type component to outputs
93
  self.tabs["train_tab"].components["lora_rank"],
94
  self.tabs["train_tab"].components["lora_alpha"],
95
  self.tabs["train_tab"].components["num_epochs"],
96
  self.tabs["train_tab"].components["batch_size"],
97
  self.tabs["train_tab"].components["learning_rate"],
98
- self.tabs["train_tab"].components["save_iterations"]
 
99
  ]
100
  )
101
 
@@ -114,6 +115,10 @@ class VideoTrainerUI:
114
  self.tabs["train_tab"].components["stop_btn"]
115
  ]
116
 
 
 
 
 
117
  # Add delete_checkpoints_btn only if it exists
118
  if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
119
  outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
@@ -237,6 +242,11 @@ class VideoTrainerUI:
237
  learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
238
  save_iterations_val = int(ui_state.get("save_iterations", 500))
239
 
 
 
 
 
 
240
  # Return all values in the exact order expected by outputs
241
  return (
242
  video_list,
@@ -252,7 +262,8 @@ class VideoTrainerUI:
252
  num_epochs_val,
253
  batch_size_val,
254
  learning_rate_val,
255
- save_iterations_val
 
256
  )
257
 
258
  def initialize_ui_from_state(self):
@@ -293,7 +304,7 @@ class VideoTrainerUI:
293
  ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
294
 
295
  return ui_state
296
-
297
  # Add this new method to get initial button states:
298
  def get_initial_button_states(self):
299
  """Get the initial states for training buttons based on recovery status"""
 
89
  self.tabs["train_tab"].components["pause_resume_btn"],
90
  self.tabs["train_tab"].components["training_preset"],
91
  self.tabs["train_tab"].components["model_type"],
92
+ self.tabs["train_tab"].components["training_type"],
93
  self.tabs["train_tab"].components["lora_rank"],
94
  self.tabs["train_tab"].components["lora_alpha"],
95
  self.tabs["train_tab"].components["num_epochs"],
96
  self.tabs["train_tab"].components["batch_size"],
97
  self.tabs["train_tab"].components["learning_rate"],
98
+ self.tabs["train_tab"].components["save_iterations"],
99
+ self.tabs["train_tab"].components["current_task_box"] # Add new component
100
  ]
101
  )
102
 
 
115
  self.tabs["train_tab"].components["stop_btn"]
116
  ]
117
 
118
+ # Add current_task_box component
119
+ if "current_task_box" in self.tabs["train_tab"].components:
120
+ outputs.append(self.tabs["train_tab"].components["current_task_box"])
121
+
122
  # Add delete_checkpoints_btn only if it exists
123
  if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
124
  outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
 
242
  learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
243
  save_iterations_val = int(ui_state.get("save_iterations", 500))
244
 
245
+ # Initial current task value
246
+ current_task_val = ""
247
+ if hasattr(self, 'log_parser') and self.log_parser:
248
+ current_task_val = self.log_parser.get_current_task_display()
249
+
250
  # Return all values in the exact order expected by outputs
251
  return (
252
  video_list,
 
262
  num_epochs_val,
263
  batch_size_val,
264
  learning_rate_val,
265
+ save_iterations_val,
266
+ current_task_val # Add current task value
267
  )
268
 
269
  def initialize_ui_from_state(self):
 
304
  ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
305
 
306
  return ui_state
307
+
308
  # Add this new method to get initial button states:
309
  def get_initial_button_states(self):
310
  """Get the initial states for training buttons based on recovery status"""
vms/utils/training_log_parser.py CHANGED
@@ -1,7 +1,7 @@
1
  import re
2
  import logging
3
  from dataclasses import dataclass
4
- from typing import Optional, Dict, Any
5
  from datetime import datetime, timedelta
6
 
7
  logger = logging.getLogger(__name__)
@@ -25,6 +25,22 @@ class TrainingState:
25
  error_message: Optional[str] = None
26
  initialization_stage: str = ""
27
  download_progress: float = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def calculate_progress(self) -> float:
30
  """Calculate overall progress as percentage"""
@@ -44,7 +60,7 @@ class TrainingState:
44
  # Use precomputed remaining time from logs if available
45
  remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
46
 
47
- return {
48
  "status": self.status,
49
  "progress": f"{self.calculate_progress():.1f}%",
50
  "current_step": self.current_step,
@@ -61,6 +77,96 @@ class TrainingState:
61
  "error_message": self.error_message,
62
  "download_progress": self.download_progress
63
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  class TrainingLogParser:
66
  """Parser for training logs with state management"""
@@ -68,12 +174,30 @@ class TrainingLogParser:
68
  def __init__(self):
69
  self.state = TrainingState()
70
  self._last_update_time = None
 
 
71
 
 
 
 
 
 
 
 
 
 
72
  def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
73
  """Parse a single log line and update state"""
74
  try:
75
- # For debugging
76
- #logger.info(f"Parsing line: {line[:100]}...")
 
 
 
 
 
 
 
77
 
78
  # Training step progress line example:
79
  # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
@@ -157,16 +281,16 @@ class TrainingLogParser:
157
 
158
  # Completion states
159
  if "Training completed successfully" in line:
160
- self.status = "completed"
161
  # Store final elapsed time
162
- self.last_step_time = datetime.now()
163
  logger.info("Training completed")
164
  return self.state.to_dict()
165
 
166
  if any(x in line for x in ["Training process stopped", "Training stopped"]):
167
- self.status = "stopped"
168
  # Store final elapsed time
169
- self.last_step_time = datetime.now()
170
  logger.info("Training stopped")
171
  return self.state.to_dict()
172
 
@@ -179,9 +303,4 @@ class TrainingLogParser:
179
  except Exception as e:
180
  logger.error(f"Error parsing line: {str(e)}")
181
 
182
- return None
183
-
184
- def reset(self):
185
- """Reset parser state"""
186
- self.state = TrainingState()
187
- self._last_update_time = None
 
1
  import re
2
  import logging
3
  from dataclasses import dataclass
4
+ from typing import Optional, Dict, Any, List
5
  from datetime import datetime, timedelta
6
 
7
  logger = logging.getLogger(__name__)
 
25
  error_message: Optional[str] = None
26
  initialization_stage: str = ""
27
  download_progress: float = 0.0
28
+
29
+ # New fields for current task tracking
30
+ current_task: str = ""
31
+ current_task_progress: str = ""
32
+ task_progress_percentage: float = 0.0
33
+ task_items_processed: int = 0
34
+ task_total_items: int = 0
35
+ task_time_remaining: str = ""
36
+ task_speed: str = ""
37
+
38
+ # Store recent progress lines for task display
39
+ recent_progress_lines: List[str] = None
40
+
41
+ def __post_init__(self):
42
+ if self.recent_progress_lines is None:
43
+ self.recent_progress_lines = []
44
 
45
  def calculate_progress(self) -> float:
46
  """Calculate overall progress as percentage"""
 
60
  # Use precomputed remaining time from logs if available
61
  remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
62
 
63
+ result = {
64
  "status": self.status,
65
  "progress": f"{self.calculate_progress():.1f}%",
66
  "current_step": self.current_step,
 
77
  "error_message": self.error_message,
78
  "download_progress": self.download_progress
79
  }
80
+
81
+ # Add current task information
82
+ result["current_task"] = self.get_task_display()
83
+
84
+ return result
85
+
86
+ def get_task_display(self) -> str:
87
+ """Generate a formatted display of the current task"""
88
+ if not self.recent_progress_lines:
89
+ if self.status == "training":
90
+ return "Training in progress..."
91
+ return ""
92
+
93
+ # Get the most recent progress line
94
+ latest_line = self.recent_progress_lines[-1]
95
+
96
+ # For downloading shards or loading checkpoint shards
97
+ if "Downloading shards" in latest_line or "Loading checkpoint shards" in latest_line:
98
+ # Extract just the progress bar part
99
+ match = re.search(r'(\d+%\|[β–β–Žβ–β–Œβ–‹β–Šβ–‰β–ˆ\s]+\|)', latest_line)
100
+ if match:
101
+ progress_bar = match.group(1)
102
+
103
+ # Extract the remaining information
104
+ time_match = re.search(r'\[(\d+:\d+<\d+:\d+,\s+[\d.]+s/it)', latest_line)
105
+ time_info = time_match.group(1) if time_match else ""
106
+
107
+ task_type = "Downloading shards" if "Downloading shards" in latest_line else "Loading checkpoint shards"
108
+
109
+ return f"{task_type}:\n{progress_bar}\n{time_info}"
110
+
111
+ # For "Rank 0" progress (typically training steps)
112
+ elif "Rank 0:" in latest_line:
113
+ match = re.search(r'Rank 0:\s+(\d+%\|[β–β–Žβ–β–Œβ–‹β–Šβ–‰β–ˆ\s]+\|)', latest_line)
114
+ if match:
115
+ progress_bar = match.group(1)
116
+
117
+ # Extract step information
118
+ step_match = re.search(r'\|\s+(\d+/\d+)', latest_line)
119
+ step_info = step_match.group(1) if step_match else ""
120
+
121
+ # Extract time information
122
+ time_match = re.search(r'\[(\d+:\d+<\d+:\d+,\s+[\d.]+s/it)', latest_line)
123
+ time_info = time_match.group(1) if time_match else ""
124
+
125
+ return f"Training iteration:\n{progress_bar} {step_info}\n{time_info}"
126
+
127
+ # For Filling buffer progress
128
+ elif "Filling buffer" in latest_line:
129
+ match = re.search(r'(\d+%\|[β–β–Žβ–β–Œβ–‹β–Šβ–‰β–ˆ\s]+\|)', latest_line)
130
+ if match:
131
+ progress_bar = match.group(1)
132
+
133
+ # Extract step information
134
+ step_match = re.search(r'\|\s+(\d+/\d+)', latest_line)
135
+ step_info = step_match.group(1) if step_match else ""
136
+
137
+ # Extract time information
138
+ time_match = re.search(r'\[(\d+:\d+<\d+:\d+,\s+[\d.]+s/it)', latest_line)
139
+ time_info = time_match.group(1) if time_match else ""
140
+
141
+ return f"Filling buffer from data iterator:\n{progress_bar} {step_info}\n{time_info}"
142
+
143
+ # For other progress lines
144
+ elif "%" in latest_line and "|" in latest_line:
145
+ # Generic progress bar pattern
146
+ match = re.search(r'(\d+%\|[β–β–Žβ–β–Œβ–‹β–Šβ–‰β–ˆ\s]+\|)', latest_line)
147
+ if match:
148
+ progress_bar = match.group(1)
149
+
150
+ # Try to extract step information
151
+ step_match = re.search(r'\|\s+(\d+/\d+)', latest_line)
152
+ step_info = step_match.group(1) if step_match else ""
153
+
154
+ # Try to extract time information
155
+ time_match = re.search(r'\[(\d+:\d+<\d+:\d+,\s+[\d.]+s/it)', latest_line)
156
+ time_info = time_match.group(1) if time_match else ""
157
+
158
+ task_prefix = "Processing:"
159
+
160
+ # Try to determine task type
161
+ if "Training" in latest_line:
162
+ task_prefix = "Training:"
163
+ elif "Precomputing" in latest_line:
164
+ task_prefix = "Precomputing:"
165
+
166
+ return f"{task_prefix}\n{progress_bar} {step_info}\n{time_info}"
167
+
168
+ # If we couldn't parse it properly, just return the line
169
+ return latest_line.strip()
170
 
171
  class TrainingLogParser:
172
  """Parser for training logs with state management"""
 
174
  def __init__(self):
175
  self.state = TrainingState()
176
  self._last_update_time = None
177
+ # Maximum number of recent progress lines to store
178
+ self.max_recent_lines = 5
179
 
180
+ def reset(self):
181
+ """Reset parser state"""
182
+ self.state = TrainingState()
183
+ self._last_update_time = None
184
+
185
+ def get_current_task_display(self) -> str:
186
+ """Get the formatted current task display"""
187
+ return self.state.get_task_display()
188
+
189
  def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
190
  """Parse a single log line and update state"""
191
  try:
192
+ # Check if this is a progress line
193
+ if any(pattern in line for pattern in ["Downloading shards:", "Loading checkpoint shards:", "Rank 0:", "Filling buffer", "|"]) and "%" in line:
194
+ # Add to recent progress lines, maintaining order and max length
195
+ self.state.recent_progress_lines.append(line)
196
+ if len(self.state.recent_progress_lines) > self.max_recent_lines:
197
+ self.state.recent_progress_lines.pop(0)
198
+
199
+ # Return updated state
200
+ return self.state.to_dict()
201
 
202
  # Training step progress line example:
203
  # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
 
281
 
282
  # Completion states
283
  if "Training completed successfully" in line:
284
+ self.state.status = "completed"
285
  # Store final elapsed time
286
+ self.state.last_step_time = datetime.now()
287
  logger.info("Training completed")
288
  return self.state.to_dict()
289
 
290
  if any(x in line for x in ["Training process stopped", "Training stopped"]):
291
+ self.state.status = "stopped"
292
  # Store final elapsed time
293
+ self.state.last_step_time = datetime.now()
294
  logger.info("Training stopped")
295
  return self.state.to_dict()
296
 
 
303
  except Exception as e:
304
  logger.error(f"Error parsing line: {str(e)}")
305
 
306
+ return None