VideoModelStudio / training_log_parser.py
jbilcke-hf's picture
jbilcke-hf HF Staff
initial commit log 🪵🦫
91fb4ef
raw
history blame
8.21 kB
import re
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
@dataclass
class TrainingState:
"""Represents the current state of training"""
status: str = "idle" # idle, initializing, training, completed, error, stopped
current_step: int = 0
total_steps: int = 0
current_epoch: int = 0
total_epochs: int = 0
step_loss: float = 0.0
learning_rate: float = 0.0
grad_norm: float = 0.0
memory_allocated: float = 0.0
memory_reserved: float = 0.0
start_time: Optional[datetime] = None
last_step_time: Optional[datetime] = None
estimated_remaining: Optional[timedelta] = None
error_message: Optional[str] = None
initialization_stage: str = ""
download_progress: float = 0.0
def calculate_progress(self) -> float:
"""Calculate overall progress as percentage"""
if self.total_steps == 0:
return 0.0
return (self.current_step / self.total_steps) * 100
def to_dict(self) -> Dict[str, Any]:
"""Convert state to dictionary for UI updates"""
elapsed = str(datetime.now() - self.start_time) if self.start_time else "0:00:00"
remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
return {
"status": self.status,
"progress": f"{self.calculate_progress():.1f}%",
"current_step": self.current_step,
"total_steps": self.total_steps,
"current_epoch": self.current_epoch,
"total_epochs": self.total_epochs,
"step_loss": f"{self.step_loss:.4f}",
"learning_rate": f"{self.learning_rate:.2e}",
"grad_norm": f"{self.grad_norm:.4f}",
"memory": f"{self.memory_allocated:.1f}GB allocated, {self.memory_reserved:.1f}GB reserved",
"elapsed": elapsed,
"remaining": remaining,
"initialization_stage": self.initialization_stage,
"error_message": self.error_message,
"download_progress": self.download_progress
}
class TrainingLogParser:
"""Parser for training logs with state management"""
def __init__(self):
self.state = TrainingState()
self._last_update_time = None
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""Parse a single log line and update state"""
try:
# For debugging
logger.info(f"Parsing line: {line[:100]}...")
# Training step progress line example:
# Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
if "Training steps:" in line:
# Set status to training if we see this
self.state.status = "training"
if not self.state.start_time:
self.state.start_time = datetime.now()
# Extract step numbers
steps_match = re.search(r"(\d+)/(\d+)", line)
if steps_match:
self.state.current_step = int(steps_match.group(1))
self.state.total_steps = int(steps_match.group(2))
# Extract metrics
for pattern, attr in [
(r"step_loss=([0-9.e-]+)", "step_loss"),
(r"lr=([0-9.e-]+)", "learning_rate"),
(r"grad_norm=([0-9.e-]+)", "grad_norm")
]:
match = re.search(pattern, line)
if match:
setattr(self.state, attr, float(match.group(1)))
# Calculate time estimates based on total elapsed time
now = datetime.now()
if self.state.start_time and self.state.current_step > 0:
# Calculate elapsed time and average time per step
elapsed_seconds = (now - self.state.start_time).total_seconds()
avg_time_per_step = elapsed_seconds / self.state.current_step
# Calculate remaining time
remaining_steps = self.state.total_steps - self.state.current_step
estimated_remaining_seconds = avg_time_per_step * remaining_steps
# Format as days, hours, minutes, seconds
days = int(estimated_remaining_seconds // (24 * 3600))
hours = int((estimated_remaining_seconds % (24 * 3600)) // 3600)
minutes = int((estimated_remaining_seconds % 3600) // 60)
seconds = int(estimated_remaining_seconds % 60)
# Create formatted timedelta
if days > 0:
formatted_time = f"{days}d {hours}h {minutes}m {seconds}s"
elif hours > 0:
formatted_time = f"{hours}h {minutes}m {seconds}s"
elif minutes > 0:
formatted_time = f"{minutes}m {seconds}s"
else:
formatted_time = f"{seconds}s"
self.state.estimated_remaining = formatted_time
self.state.last_step_time = now
logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}")
return self.state.to_dict()
# Epoch information
# there is an issue with how epoch is reported because we display:
# Progress: 96.9%, Step: 872/900, Epoch: 12/50
# we should probably just show the steps
epoch_match = re.search(r"Starting epoch \((\d+)/(\d+)\)", line)
if epoch_match:
self.state.current_epoch = int(epoch_match.group(1))
self.state.total_epochs = int(epoch_match.group(2))
logger.info(f"Updated epoch: {self.state.current_epoch}/{self.state.total_epochs}")
return self.state.to_dict()
# Initialization stages
if "Initializing" in line:
self.state.status = "initializing"
self.state.initialization_stage = line.split("Initializing")[1].strip()
logger.info(f"Initialization stage: {self.state.initialization_stage}")
return self.state.to_dict()
# Memory usage
if "memory_allocated" in line:
mem_match = re.search(r'"memory_allocated":\s*([0-9.]+)', line)
if mem_match:
self.state.memory_allocated = float(mem_match.group(1))
reserved_match = re.search(r'"memory_reserved":\s*([0-9.]+)', line)
if reserved_match:
self.state.memory_reserved = float(reserved_match.group(1))
logger.info(f"Updated memory: allocated={self.state.memory_allocated}GB, reserved={self.state.memory_reserved}GB")
return self.state.to_dict()
# Completion states
if "Training completed successfully" in line:
self.state.status = "completed"
logger.info("Training completed")
return self.state.to_dict()
if any(x in line for x in ["Training process stopped", "Training stopped"]):
self.state.status = "stopped"
logger.info("Training stopped")
return self.state.to_dict()
if "Error during training:" in line:
self.state.status = "error"
self.state.error_message = line.split("Error during training:")[1].strip()
logger.info(f"Training error: {self.state.error_message}")
return self.state.to_dict()
except Exception as e:
logger.error(f"Error parsing line: {str(e)}")
return None
def reset(self):
"""Reset parser state"""
self.state = TrainingState()
self._last_update_time = None