import io
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch
import torch.distributed.checkpoint.stateful

from .parallel import ParallelBackendType
from .utils import get_device_info


_device_type, _ = get_device_info()


@dataclass
class TrainState(torch.distributed.checkpoint.stateful.Stateful):
    step: int = 0
    observed_data_samples: int = 0
    observed_num_tokens: int = 0
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead on every iteration.
        global_avg_losses_bytes = io.BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = io.BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = io.BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32),
            "observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int32),
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Rewind the serialized buffers before deserializing them.
        state_dict["global_avg_losses"].seek(0)
        state_dict["global_max_losses"].seek(0)
        state_dict["log_steps"].seek(0)
        self.step = state_dict["step"].item()
        self.observed_data_samples = state_dict["observed_data_samples"].item()
        self.observed_num_tokens = state_dict["observed_num_tokens"].item()
        self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
        self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


@dataclass
class State:
    # Parallel state
    parallel_backend: ParallelBackendType = None

    # Training state
    train_state: TrainState = None
    num_trainable_parameters: int = 0
    generator: torch.Generator = None

    # Hub state
    repo_id: str = None

    # Artifacts state
    output_dir: str = None
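

if __name__ == "__main__":
    # Illustrative sketch, not part of the original module: round-trip a
    # TrainState through state_dict() / load_state_dict(), the same calls the
    # torch.distributed.checkpoint machinery makes on a Stateful object when
    # saving and restoring a run. The concrete values are made up for
    # demonstration, and the block assumes the module is executed within its
    # package (e.g. `python -m <package>.state`) so the relative imports resolve.
    train_state = TrainState(step=100, observed_data_samples=6_400, observed_num_tokens=1_000_000)
    train_state.global_avg_losses.append(0.42)
    train_state.global_max_losses.append(1.37)
    train_state.log_steps.append(100)

    # state_dict() packs the scalar counters as tensors and the per-log-step
    # lists as torch-serialized BytesIO buffers.
    payload = train_state.state_dict()

    restored = TrainState()
    restored.load_state_dict(payload)
    assert restored.step == 100
    assert restored.observed_num_tokens == 1_000_000
    assert restored.global_avg_losses == [0.42]
    assert restored.log_steps == [100]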