import io
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch
import torch.distributed.checkpoint.stateful

from .parallel import ParallelBackendType
from .utils import get_device_info


_device_type, _ = get_device_info()


@dataclass
class TrainState(torch.distributed.checkpoint.stateful.Stateful):
    # Global training progress and loss statistics. Implements the
    # torch.distributed.checkpoint `Stateful` protocol so this state can be
    # saved and restored as part of a distributed checkpoint.
    step: int = 0
    observed_data_samples: int = 0
    observed_num_tokens: int = 0
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        global_avg_losses_bytes = io.BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = io.BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = io.BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32),
            "observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int32),
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Rewind the serialized buffers before deserializing them.
        state_dict["global_avg_losses"].seek(0)
        state_dict["global_max_losses"].seek(0)
        state_dict["log_steps"].seek(0)
        self.step = state_dict["step"].item()
        self.observed_data_samples = state_dict["observed_data_samples"].item()
        self.observed_num_tokens = state_dict["observed_num_tokens"].item()
        self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
        self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


@dataclass
class State:
    # Parallel state
    parallel_backend: ParallelBackendType = None

    # Training state
    train_state: TrainState = None
    num_trainable_parameters: int = 0
    generator: torch.Generator = None

    # Hub state
    repo_id: str = None

    # Artifacts state
    output_dir: str = None
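

# --- Illustrative sketch, not part of the original module ---
# A minimal single-process round-trip through the Stateful interface above,
# showing the shape of what torch.distributed.checkpoint serializes and
# restores for TrainState. The values below are arbitrary examples.
if __name__ == "__main__":
    _ts = TrainState(step=10, observed_data_samples=80, observed_num_tokens=4096)
    _ts.log_steps.append(10)
    _ts.global_avg_losses.append(0.123)
    _ts.global_max_losses.append(0.456)

    # Serialize to the checkpoint-friendly dict, then restore into a fresh instance.
    _restored = TrainState()
    _restored.load_state_dict(_ts.state_dict())

    assert _restored.step == 10
    assert _restored.observed_num_tokens == 4096
    assert _restored.global_avg_losses == [0.123]
    assert _restored.log_steps == [10]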