# NOTE: removed a page-scrape artifact ("Spaces: Running Running") that was not part of the module source.
import pathlib
from typing import Any, Callable, Dict, Iterable, Optional

import torch
from tqdm.auto import tqdm

from .. import utils
class DistributedDataPreprocessor:
    """Runs per-data-type processor functions over a stream of items on one
    distributed rank and serializes each result to disk as a ``.pt`` file, so
    the processed items can later be re-read without recomputation.

    Args:
        rank: Index of the distributed process; used to namespace saved files.
        num_items: Number of items pulled from the data iterator per consume call.
        processor_fn: Mapping from data-type name to the callable that turns a
            raw item dict (plus ``components`` and ``generator`` kwargs) into
            its processed form.
        save_dir: Directory where processed items are written.
    """

    def __init__(
        self,
        rank: int,
        num_items: int,
        processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
        save_dir: str,
    ) -> None:
        self._rank = rank
        self._num_items = num_items
        self._processor_fn = processor_fn
        self._save_dir = pathlib.Path(save_dir)
        # Raw (unprocessed) items retained when cache_samples=True so a later
        # consume call can replay them via use_cached_samples=True.
        self._cached_samples = []
        # Fixed annotation: starts out unset until the first consume* call.
        self._preprocessed_iterator: Optional["PreprocessedDataIterable"] = None

        self._save_dir.mkdir(parents=True, exist_ok=True)
        # Clear stale artifacts left in subdirectories by a previous run.
        subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()]
        utils.delete_files(subdirectories)

    def consume(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Process ``num_items`` items and return a one-pass iterator over the
        saved results (see :class:`PreprocessedDataIterable`).

        Args:
            data_type: Key into ``processor_fn`` selecting the processor.
            components: Extra keyword arguments forwarded to the processor.
            data_iterator: Source of raw item dicts (consumed via ``next``).
            generator: Optional RNG forwarded to the processor.
            cache_samples: Keep the raw items in memory for later replay.
            use_cached_samples: Read raw items from the in-memory cache instead
                of ``data_iterator``.
            drop_samples: Discard the in-memory cache afterwards and free memory.

        Raises:
            ValueError: On an unknown ``data_type`` or conflicting cache flags.
        """
        return self._consume(
            data_type,
            components,
            data_iterator,
            generator,
            cache_samples,
            use_cached_samples,
            drop_samples,
            desc=f"Rank {self._rank}",
            iterator_cls=PreprocessedDataIterable,
        )

    def consume_once(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Like :meth:`consume`, but the returned iterator cycles over the saved
        results indefinitely (see :class:`PreprocessedOnceDataIterable`)."""
        return self._consume(
            data_type,
            components,
            data_iterator,
            generator,
            cache_samples,
            use_cached_samples,
            drop_samples,
            desc=f"Processing data on rank {self._rank}",
            iterator_cls=PreprocessedOnceDataIterable,
        )

    def _consume(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator],
        cache_samples: bool,
        use_cached_samples: bool,
        drop_samples: bool,
        desc: str,
        iterator_cls,
    ) -> Iterable[Dict[str, Any]]:
        # Shared implementation for consume/consume_once; the two public
        # methods were previously byte-identical except for the progress-bar
        # description and the iterable class constructed at the end.
        if data_type not in self._processor_fn.keys():
            raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
        if cache_samples:
            if use_cached_samples:
                raise ValueError("Cannot cache and use cached samples at the same time.")
            if drop_samples:
                raise ValueError("Cannot cache and drop samples at the same time.")

        for i in tqdm(range(self._num_items), desc=desc, total=self._num_items):
            item = self._cached_samples[i] if use_cached_samples else next(data_iterator)
            if cache_samples:
                self._cached_samples.append(item)
            item = self._processor_fn[data_type](**item, **components, generator=generator)
            _save_item(self._rank, i, item, self._save_dir, data_type)

        if drop_samples:
            del self._cached_samples
            self._cached_samples = []
            utils.free_memory()

        self._preprocessed_iterator = iterator_cls(self._rank, self._save_dir, data_type)
        return iter(self._preprocessed_iterator)

    def requires_data(self):
        """Return True when fresh raw data must be consumed before iterating."""
        if self._preprocessed_iterator is None:
            return True
        # BUG FIX: ``requires_data`` on the iterable classes is a method, not a
        # property — the original returned the bound method object, which is
        # always truthy. It must be called to get the boolean.
        return self._preprocessed_iterator.requires_data()
class PreprocessedDataIterable:
    """One-pass view over the preprocessed ``.pt`` files a given rank wrote to
    ``save_dir`` for a particular data type. Once the final item has been
    yielded, ``requires_data()`` starts returning True to signal that new raw
    data must be consumed."""

    def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
        self._rank = rank
        self._save_dir = pathlib.Path(save_dir)
        self._data_type = data_type
        # Count how many files this rank produced for this data type.
        self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
        self._requires_data = False

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        last = self._num_items - 1
        for index in range(self._num_items):
            # Flip the flag as the final item is handed out, so callers polling
            # requires_data() know this pass is exhausted.
            if index == last:
                self._requires_data = True
            yield _load_item(self._rank, index, self._save_dir, self._data_type)

    def __len__(self) -> int:
        return self._num_items

    def requires_data(self):
        return self._requires_data
class PreprocessedOnceDataIterable:
    """Endless cycling view over the preprocessed ``.pt`` files a given rank
    wrote to ``save_dir`` for a particular data type. Because the saved items
    are reused forever, ``requires_data()`` always returns False."""

    def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
        self._rank = rank
        self._save_dir = pathlib.Path(save_dir)
        self._data_type = data_type
        # Count how many files this rank produced for this data type.
        self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
        self._requires_data = False

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        # Infinite generator: wraps around to the first item after the last.
        position = 0
        while True:
            yield _load_item(self._rank, position, self._save_dir, self._data_type)
            position = (position + 1) % self._num_items

    def __len__(self) -> int:
        return self._num_items

    def requires_data(self):
        return self._requires_data
def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None: | |
filename = directory / f"{data_type}-{rank}-{index}.pt" | |
torch.save(item, filename.as_posix()) | |
def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]: | |
filename = directory / f"{data_type}-{rank}-{index}.pt" | |
return torch.load(filename.as_posix(), weights_only=True) | |