Spaces:
Runtime error
Runtime error
import re | |
import logging | |
from dataclasses import dataclass | |
from typing import Optional, Dict, Any | |
from datetime import datetime, timedelta | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class TrainingState:
    """Represents the current state of training.

    Every field has a default, so ``TrainingState()`` describes an idle run
    that has not reported anything yet. Because the class is a dataclass,
    initial values can also be passed as keyword arguments and instances get
    a field-by-field ``repr`` and equality.
    """

    # One of: idle, initializing, training, completed, error, stopped
    status: str = "idle"

    # Step / epoch counters as last reported by the training logs.
    current_step: int = 0
    total_steps: int = 0
    current_epoch: int = 0
    total_epochs: int = 0

    # Latest per-step metrics.
    step_loss: float = 0.0
    learning_rate: float = 0.0
    grad_norm: float = 0.0

    # Memory figures; to_dict() renders them with a "GB" suffix.
    memory_allocated: float = 0.0
    memory_reserved: float = 0.0

    # Timing information. None means "not known yet".
    start_time: Optional[datetime] = None
    last_step_time: Optional[datetime] = None
    # NOTE(review): annotated as a timedelta, but TrainingLogParser.parse_line
    # stores an already formatted string such as "1h 2m 5s" here. to_dict()
    # stringifies the value, so both forms are rendered correctly.
    estimated_remaining: Optional[timedelta] = None

    # Text following "Error during training:" when a run fails.
    error_message: Optional[str] = None
    # Text following "Initializing" in the most recent initialization line.
    initialization_stage: str = ""
    # Not updated by any code visible in this file; exported as-is by to_dict().
    download_progress: float = 0.0

    def calculate_progress(self) -> float:
        """Calculate overall progress as a percentage of total steps.

        Returns:
            ``current_step / total_steps * 100``, or ``0.0`` while the total
            number of steps is still unknown (zero).
        """
        if self.total_steps == 0:
            return 0.0
        return (self.current_step / self.total_steps) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a dictionary for UI updates.

        Metrics, progress, memory and timings are pre-formatted as strings;
        step and epoch counters and ``download_progress`` are passed through
        unchanged.
        """
        if self.start_time is not None:
            elapsed = str(datetime.now() - self.start_time)
        else:
            elapsed = "0:00:00"

        # Compare against None rather than relying on truthiness: a zero
        # timedelta is falsy but is still a valid "nothing left" estimate.
        if self.estimated_remaining is not None:
            remaining = str(self.estimated_remaining)
        else:
            remaining = "calculating..."

        return {
            "status": self.status,
            "progress": f"{self.calculate_progress():.1f}%",
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "current_epoch": self.current_epoch,
            "total_epochs": self.total_epochs,
            "step_loss": f"{self.step_loss:.4f}",
            "learning_rate": f"{self.learning_rate:.2e}",
            "grad_norm": f"{self.grad_norm:.4f}",
            "memory": f"{self.memory_allocated:.1f}GB allocated, {self.memory_reserved:.1f}GB reserved",
            "elapsed": elapsed,
            "remaining": remaining,
            "initialization_stage": self.initialization_stage,
            "error_message": self.error_message,
            "download_progress": self.download_progress,
        }
class TrainingLogParser:
    """Parser for training logs with state management.

    Feed log lines one at a time to :meth:`parse_line`. Lines that carry
    recognised information update ``self.state`` and yield a UI-ready
    snapshot of it; all other lines yield ``None``.
    """

    # Numeric token as printed in the progress-bar postfix: plain decimals,
    # scientific notation with a signed exponent (e.g. "3e-7", "1.23e+3"),
    # or the special values nan / inf. A looser character class such as
    # "[0-9.e-]+" cuts "1.23e+3" down to "1.23e", which float() rejects.
    _NUMBER = r"([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))"

    # (compiled pattern, TrainingState attribute) pairs for per-step metrics.
    _METRIC_PATTERNS = (
        (re.compile(r"step_loss=" + _NUMBER), "step_loss"),
        (re.compile(r"lr=" + _NUMBER), "learning_rate"),
        (re.compile(r"grad_norm=" + _NUMBER), "grad_norm"),
    )

    _STEPS_RE = re.compile(r"(\d+)/(\d+)")
    _EPOCH_RE = re.compile(r"Starting epoch \((\d+)/(\d+)\)")
    _MEM_ALLOCATED_RE = re.compile(r'"memory_allocated":\s*([0-9.]+)')
    _MEM_RESERVED_RE = re.compile(r'"memory_reserved":\s*([0-9.]+)')

    def __init__(self):
        """Start with a fresh, idle ``TrainingState``."""
        self.state = TrainingState()
        self._last_update_time = None

    @classmethod
    def _extract_metrics(cls, line: str) -> Dict[str, float]:
        """Return the per-step metrics found in *line*, keyed by state attribute."""
        metrics: Dict[str, float] = {}
        for pattern, attr in cls._METRIC_PATTERNS:
            match = pattern.search(line)
            if match:
                metrics[attr] = float(match.group(1))
        return metrics

    @staticmethod
    def _format_duration(total_seconds: float) -> str:
        """Format a duration in seconds as e.g. ``1d 2h 3m 4s``, omitting leading zero units."""
        days = int(total_seconds // (24 * 3600))
        hours = int((total_seconds % (24 * 3600)) // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = int(total_seconds % 60)

        if days > 0:
            return f"{days}d {hours}h {minutes}m {seconds}s"
        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        if minutes > 0:
            return f"{minutes}m {seconds}s"
        return f"{seconds}s"

    def _handle_progress_line(self, line: str) -> Dict[str, Any]:
        """Update step counters, metrics and the time estimate from a progress-bar line."""
        state = self.state

        # Seeing a progress line implies that training is running.
        state.status = "training"
        if not state.start_time:
            state.start_time = datetime.now()

        # Extract step numbers (first "<current>/<total>" pair on the line).
        steps_match = self._STEPS_RE.search(line)
        if steps_match:
            state.current_step = int(steps_match.group(1))
            state.total_steps = int(steps_match.group(2))

        # Extract metrics.
        for attr, value in self._extract_metrics(line).items():
            setattr(state, attr, value)

        # Estimate the remaining time from the average time per step so far.
        # NOTE(review): the average is elapsed time since this parser first
        # saw a progress line divided by the reported step number, so a run
        # whose first reported step is large looks faster than it is.
        now = datetime.now()
        if state.start_time and state.current_step > 0:
            elapsed_seconds = (now - state.start_time).total_seconds()
            avg_time_per_step = elapsed_seconds / state.current_step
            remaining_steps = state.total_steps - state.current_step
            state.estimated_remaining = self._format_duration(
                avg_time_per_step * remaining_steps
            )
        state.last_step_time = now

        logger.info(
            "Updated training state: step=%s/%s, loss=%s",
            state.current_step,
            state.total_steps,
            state.step_loss,
        )
        return state.to_dict()

    def _handle_memory_line(self, line: str) -> Dict[str, Any]:
        """Update allocated / reserved memory figures from a memory report line."""
        state = self.state

        allocated_match = self._MEM_ALLOCATED_RE.search(line)
        if allocated_match:
            state.memory_allocated = float(allocated_match.group(1))

        reserved_match = self._MEM_RESERVED_RE.search(line)
        if reserved_match:
            state.memory_reserved = float(reserved_match.group(1))

        logger.info(
            "Updated memory: allocated=%sGB, reserved=%sGB",
            state.memory_allocated,
            state.memory_reserved,
        )
        return state.to_dict()

    def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
        """Parse a single log line and update state.

        Args:
            line: One line of training log output.

        Returns:
            ``self.state.to_dict()`` when the line was recognised and the
            state was updated; ``None`` for unrecognised lines or when
            parsing raised (the error is logged, never propagated).
        """
        try:
            # Training step progress line example:
            # Training steps: 1%|▏ | 1/70 [00:14<16:11, 14.08s/it, grad_norm=0.00789, step_loss=0.555, lr=3e-7]
            if ("Started training" in line) or ("Starting training" in line):
                self.state.status = "training"

            if "Training steps:" in line:
                return self._handle_progress_line(line)

            # Epoch information
            # there is an issue with how epoch is reported because we display:
            # Progress: 96.9%, Step: 872/900, Epoch: 12/50
            # we should probably just show the steps
            epoch_match = self._EPOCH_RE.search(line)
            if epoch_match:
                self.state.current_epoch = int(epoch_match.group(1))
                self.state.total_epochs = int(epoch_match.group(2))
                logger.info(
                    "Updated epoch: %s/%s",
                    self.state.current_epoch,
                    self.state.total_epochs,
                )
                return self.state.to_dict()

            # Initialization stages
            if "Initializing" in line:
                self.state.status = "initializing"
                self.state.initialization_stage = line.split("Initializing")[1].strip()
                logger.info("Initialization stage: %s", self.state.initialization_stage)
                return self.state.to_dict()

            # Memory usage
            if "memory_allocated" in line:
                return self._handle_memory_line(line)

            # Completion states
            if "Training completed successfully" in line:
                self.state.status = "completed"
                logger.info("Training completed")
                return self.state.to_dict()

            if any(x in line for x in ["Training process stopped", "Training stopped"]):
                self.state.status = "stopped"
                logger.info("Training stopped")
                return self.state.to_dict()

            if "Error during training:" in line:
                self.state.status = "error"
                self.state.error_message = line.split("Error during training:")[1].strip()
                logger.info("Training error: %s", self.state.error_message)
                return self.state.to_dict()

        except Exception as e:
            # Best effort: a malformed line must never break log streaming.
            logger.error("Error parsing line: %s", e)

        return None

    def reset(self):
        """Reset parser state to a fresh, idle ``TrainingState``."""
        self.state = TrainingState()
        self._last_update_time = None