import pickle
from typing import Any, Callable, Dict, Optional

import torch.distributed.checkpoint.stateful
import torch.utils.data
import torchdata.stateful_dataloader

from ..logging import get_logger


logger = get_logger()


class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
    """A stateful dataloader for data-parallel training.

    Wraps torchdata's StatefulDataLoader and implements the
    torch.distributed.checkpoint Stateful protocol, so the dataloader state
    can be saved and restored per data-parallel rank as part of a
    distributed checkpoint.
    """

    def __init__(
        self,
        rank: int,
        dataset: torch.utils.data.IterableDataset,
        batch_size: int = 1,
        num_workers: int = 0,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)

        self._dp_rank = rank
        self._rank_id = f"dp_rank_{rank}"

    def state_dict(self) -> Dict[str, Any]:
        # Store state only under this dp rank's key, to avoid replicating the same
        # state across the other parallel dimensions; the inner state dict is
        # pickled so it is checkpointed as a single bytes object.
        return {self._rank_id: pickle.dumps(super().state_dict())}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # An empty state is valid (e.g. when no checkpoint has been saved yet)
        if not state_dict:
            return

        if self._rank_id not in state_dict:
            logger.warning(f"DataLoader state not found for dp rank {self._dp_rank} (expected key '{self._rank_id}')")
            return

        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
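

# Minimal usage sketch (illustrative only, not part of this module): because
# DPDataLoader implements the torch.distributed.checkpoint Stateful protocol,
# it can be saved and restored alongside other training state. `dataset` below
# is a hypothetical placeholder, and in practice the dp rank would come from
# the data-parallel process group or device mesh rather than the global rank:
#
#     import torch.distributed as dist
#     import torch.distributed.checkpoint as dcp
#
#     dataloader = DPDataLoader(rank=dist.get_rank(), dataset=dataset, batch_size=8)
#     state = {"dataloader": dataloader}
#     dcp.save(state, checkpoint_id="/path/to/checkpoint")  # invokes dataloader.state_dict()
#     dcp.load(state, checkpoint_id="/path/to/checkpoint")  # invokes dataloader.load_state_dict()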