import pathlib
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
from tqdm.auto import tqdm

from .. import utils


class DistributedDataPreprocessor:
    """Preprocesses a fixed number of data items on one distributed rank and
    persists each processed item to disk as ``{data_type}-{rank}-{index}.pt``.

    ``processor_fn`` maps a data-type name to a callable that turns a raw item
    (expanded as keyword arguments, together with ``components`` and an optional
    ``generator``) into the dict that gets serialized via ``torch.save``.
    """

    def __init__(
        self,
        rank: int,
        num_items: int,
        processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
        save_dir: str,
    ) -> None:
        self._rank = rank
        self._num_items = num_items
        self._processor_fn = processor_fn
        self._save_dir = pathlib.Path(save_dir)

        # Raw (unprocessed) items retained across calls when cache_samples=True.
        self._cached_samples: List[Dict[str, Any]] = []
        # Set by consume()/consume_once(); None until data has been consumed.
        # NOTE: consume_once() stores a PreprocessedOnceDataIterable here, so the
        # annotation must admit both iterable types.
        self._preprocessed_iterator: Optional[
            Union["PreprocessedDataIterable", "PreprocessedOnceDataIterable"]
        ] = None

        self._save_dir.mkdir(parents=True, exist_ok=True)
        # Clear stale artifacts left behind by a previous run.
        subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()]
        utils.delete_files(subdirectories)

    def _validate(
        self,
        data_type: str,
        cache_samples: bool,
        use_cached_samples: bool,
        drop_samples: bool,
    ) -> None:
        """Reject unknown data types and mutually exclusive flag combinations."""
        if data_type not in self._processor_fn:
            raise ValueError(
                f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}"
            )
        if cache_samples:
            if use_cached_samples:
                raise ValueError("Cannot cache and use cached samples at the same time.")
            if drop_samples:
                raise ValueError("Cannot cache and drop samples at the same time.")

    def _run(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator],
        cache_samples: bool,
        use_cached_samples: bool,
        drop_samples: bool,
        desc: str,
    ) -> None:
        """Shared processing loop for consume()/consume_once().

        Pulls ``num_items`` items (from the cache or ``data_iterator``), runs the
        registered processor, and saves each result to ``save_dir``.
        """
        for i in tqdm(range(self._num_items), desc=desc, total=self._num_items):
            if use_cached_samples:
                item = self._cached_samples[i]
            else:
                item = next(data_iterator)
                if cache_samples:
                    self._cached_samples.append(item)
            item = self._processor_fn[data_type](**item, **components, generator=generator)
            _save_item(self._rank, i, item, self._save_dir, data_type)

        if drop_samples:
            # Rebinding releases the old list; no separate `del` needed.
            self._cached_samples = []
            utils.free_memory()

    def consume(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Process ``num_items`` items and return a single-pass iterator over the
        saved results (see ``PreprocessedDataIterable``).

        Args:
            data_type: Key into ``processor_fn`` selecting the processor.
            components: Extra keyword arguments forwarded to the processor.
            data_iterator: Source of raw items (consumed via ``next``) unless
                ``use_cached_samples`` is set.
            generator: Optional RNG forwarded to the processor.
            cache_samples: Retain raw items for a later ``use_cached_samples`` pass.
            use_cached_samples: Read raw items from the cache instead of the iterator.
            drop_samples: Discard the cache and free memory after processing.

        Raises:
            ValueError: On an unknown ``data_type`` or conflicting flags.
        """
        self._validate(data_type, cache_samples, use_cached_samples, drop_samples)
        self._run(
            data_type,
            components,
            data_iterator,
            generator,
            cache_samples,
            use_cached_samples,
            drop_samples,
            desc=f"Rank {self._rank}",
        )
        self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
        return iter(self._preprocessed_iterator)

    def consume_once(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Like ``consume`` but returns an endlessly cycling iterator over the
        saved results (see ``PreprocessedOnceDataIterable``)."""
        self._validate(data_type, cache_samples, use_cached_samples, drop_samples)
        self._run(
            data_type,
            components,
            data_iterator,
            generator,
            cache_samples,
            use_cached_samples,
            drop_samples,
            desc=f"Processing data on rank {self._rank}",
        )
        self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
        return iter(self._preprocessed_iterator)

    @property
    def requires_data(self) -> bool:
        """True when fresh data must be consumed before iterating again."""
        if self._preprocessed_iterator is None:
            return True
        return self._preprocessed_iterator.requires_data


class PreprocessedDataIterable:
    """Single-pass iterator over the ``.pt`` files a rank saved for ``data_type``.

    ``requires_data`` flips to True as the final item is yielded, signalling the
    owner to consume new data before the next pass.
    """

    def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
        self._rank = rank
        self._save_dir = pathlib.Path(save_dir)
        # Count the files this rank wrote for this data type.
        self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
        self._data_type = data_type
        self._requires_data = False

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        for i in range(self._num_items):
            # Flag exhaustion just before yielding the last item so the owner
            # can see it without consuming past the end.
            if i == self._num_items - 1:
                self._requires_data = True
            yield _load_item(self._rank, i, self._save_dir, self._data_type)

    def __len__(self) -> int:
        return self._num_items

    @property
    def requires_data(self) -> bool:
        return self._requires_data


class PreprocessedOnceDataIterable:
    """Endlessly cycling iterator over the ``.pt`` files a rank saved for
    ``data_type``. ``requires_data`` stays False: the same items are reused
    forever, so fresh data is never needed.
    """

    def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
        self._rank = rank
        self._save_dir = pathlib.Path(save_dir)
        self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
        self._data_type = data_type
        self._requires_data = False

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        index = 0
        while True:
            yield _load_item(self._rank, index, self._save_dir, self._data_type)
            index = (index + 1) % self._num_items

    def __len__(self) -> int:
        return self._num_items

    @property
    def requires_data(self) -> bool:
        return self._requires_data


def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
    """Serialize ``item`` to ``directory/{data_type}-{rank}-{index}.pt``."""
    filename = directory / f"{data_type}-{rank}-{index}.pt"
    torch.save(item, filename.as_posix())


def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]:
    """Load the item previously saved by ``_save_item`` for the same key.

    ``weights_only=True`` restricts deserialization to tensor/container types,
    avoiding arbitrary-code execution from untrusted pickles.
    """
    filename = directory / f"{data_type}-{rank}-{index}.pt"
    return torch.load(filename.as_posix(), weights_only=True)