import pickle
from typing import Any, Dict

import torch.distributed.checkpoint.stateful
import torch.utils.data
import torchdata.stateful_dataloader

from ..logging import get_logger

logger = get_logger()
class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
    r"""
    Stateful dataloader that namespaces its checkpoint state under a
    per-data-parallel-rank key, so each DP rank saves and restores only
    its own iteration state.
    """

    def __init__(
        self,
        rank: int,
        dataset: torch.utils.data.IterableDataset,
        batch_size: int = 1,
        num_workers: int = 0,
        collate_fn=None,
    ) -> None:
        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)

        self._dp_rank = rank
        self._rank_id = f"dp_rank_{rank}"

    def state_dict(self) -> Dict[str, Any]:
        # Store state only under this dp rank's key to avoid replicating the same state across other parallel dimensions
        return {self._rank_id: pickle.dumps(super().state_dict())}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # State being empty is valid
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")
            return

        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
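

# --- Usage sketch (illustrative, not part of the original module) ---
# The relative import above (`from ..logging import get_logger`) means this
# file cannot run standalone, so the save/resume round trip is sketched in
# comments. Names such as `my_dataset` and `dp_rank` are hypothetical. In
# practice, torch.distributed.checkpoint drives state_dict()/load_state_dict()
# because DPDataLoader implements the Stateful protocol; keying the state by
# "dp_rank_{rank}" means each data-parallel rank checkpoints only its own
# iteration progress.
#
#   loader = DPDataLoader(rank=dp_rank, dataset=my_dataset, batch_size=8)
#   for step, batch in enumerate(loader):
#       ...  # train for a while, then checkpoint
#   state = loader.state_dict()   # e.g. {"dp_rank_0": <pickled bytes>}
#
#   # On resume, a fresh loader on the same dp rank continues mid-epoch
#   # instead of restarting from the beginning of the dataset:
#   loader = DPDataLoader(rank=dp_rank, dataset=my_dataset, batch_size=8)
#   loader.load_state_dict(state)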