import io
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch
import torch.distributed.checkpoint.stateful

from .parallel import ParallelBackendType
from .utils import get_device_info


_device_type, _ = get_device_info()


def _save_to_buffer(obj: Any) -> io.BytesIO:
    """Serialize ``obj`` with ``torch.save`` into a fresh in-memory buffer."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer


@dataclass
class TrainState(torch.distributed.checkpoint.stateful.Stateful):
    """Training progress counters and logged loss history, checkpointable via ``Stateful``.

    The scalar counters are stored in the checkpoint as 0-dim tensors; the
    list-valued fields are stored as ``torch.save``-serialized ``io.BytesIO``
    buffers.
    """

    step: int = 0
    observed_data_samples: int = 0
    observed_num_tokens: int = 0
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        """Return a checkpointable snapshot of this training state.

        Returns:
            A dict with tensor entries ``"step"``, ``"observed_data_samples"``
            and ``"observed_num_tokens"``, and ``io.BytesIO`` entries
            ``"global_avg_losses"``, ``"global_max_losses"`` and ``"log_steps"``.
        """
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            # The cumulative counters use int64: an int32 tensor cannot hold
            # values above 2**31 - 1 and torch.tensor() raises on overflow,
            # which would make checkpoint saving fail on long runs.
            "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int64),
            "observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int64),
            "global_avg_losses": _save_to_buffer(self.global_avg_losses),
            "global_max_losses": _save_to_buffer(self.global_max_losses),
            "log_steps": _save_to_buffer(self.log_steps),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore this training state from a dict shaped like ``state_dict()``'s output.

        Args:
            state_dict: Mapping with the keys produced by ``state_dict()``. The
                scalar entries only need to support ``.item()``, so tensors of
                any integer dtype (including older int32 checkpoints) load fine.
        """
        # Rewind every buffer first; they may have been left at end-of-stream.
        for key in ("global_avg_losses", "global_max_losses", "log_steps"):
            state_dict[key].seek(0)

        self.step = state_dict["step"].item()
        self.observed_data_samples = state_dict["observed_data_samples"].item()
        self.observed_num_tokens = state_dict["observed_num_tokens"].item()

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects. Only load checkpoints that come from a trusted source.
        self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
        self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


@dataclass
class State:
    """Container for run-level state; every field defaults to ``None`` or ``0``."""

    # Parallel state
    parallel_backend: ParallelBackendType = None

    # Training state
    train_state: TrainState = None
    num_trainable_parameters: int = 0
    generator: torch.Generator = None

    # Hub state
    repo_id: str = None

    # Artifacts state
    output_dir: str = None