# VideoModelStudio: finetrainers/data/precomputation.py
import pathlib
from typing import Any, Callable, Dict, Iterable, Optional, Union
import torch
from tqdm.auto import tqdm
from .. import utils
class DistributedDataPreprocessor:
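    """Precomputes processed samples for a single rank and caches them to disk.

    Each sample drawn from a data iterator is run through the processor
    registered for its ``data_type`` and serialized to ``save_dir`` as a
    ``.pt`` file, so it can be streamed back later without re-running the
    preprocessing.
    """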
def __init__(
self,
rank: int,
num_items: int,
processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
save_dir: str,
) -> None:
self._rank = rank
self._num_items = num_items
self._processor_fn = processor_fn
self._save_dir = pathlib.Path(save_dir)
self._cached_samples = []
        self._preprocessed_iterator: Optional[Union["PreprocessedDataIterable", "PreprocessedOnceDataIterable"]] = None
self._save_dir.mkdir(parents=True, exist_ok=True)
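        # Remove any subdirectories left over in the save directory from previous runs.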
subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()]
utils.delete_files(subdirectories)
def consume(
self,
data_type: str,
components: Dict[str, Any],
data_iterator,
generator: Optional[torch.Generator] = None,
cache_samples: bool = False,
use_cached_samples: bool = False,
drop_samples: bool = False,
) -> Iterable[Dict[str, Any]]:
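        """Process ``self._num_items`` samples for ``data_type`` and persist them to disk.

        Returns an iterator over the saved items that yields each one exactly once;
        its ``requires_data`` flag turns True as the last item is served.
        """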
        if data_type not in self._processor_fn:
raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
if cache_samples:
if use_cached_samples:
raise ValueError("Cannot cache and use cached samples at the same time.")
if drop_samples:
raise ValueError("Cannot cache and drop samples at the same time.")
for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
if use_cached_samples:
item = self._cached_samples[i]
else:
item = next(data_iterator)
if cache_samples:
self._cached_samples.append(item)
item = self._processor_fn[data_type](**item, **components, generator=generator)
_save_item(self._rank, i, item, self._save_dir, data_type)
if drop_samples:
del self._cached_samples
self._cached_samples = []
utils.free_memory()
self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
return iter(self._preprocessed_iterator)
def consume_once(
self,
data_type: str,
components: Dict[str, Any],
data_iterator,
generator: Optional[torch.Generator] = None,
cache_samples: bool = False,
use_cached_samples: bool = False,
drop_samples: bool = False,
) -> Iterable[Dict[str, Any]]:
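        """Like ``consume``, but the returned iterator cycles over the saved items
        indefinitely instead of yielding each one only once.
        """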
        if data_type not in self._processor_fn:
raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
if cache_samples:
if use_cached_samples:
raise ValueError("Cannot cache and use cached samples at the same time.")
if drop_samples:
raise ValueError("Cannot cache and drop samples at the same time.")
for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items):
if use_cached_samples:
item = self._cached_samples[i]
else:
item = next(data_iterator)
if cache_samples:
self._cached_samples.append(item)
item = self._processor_fn[data_type](**item, **components, generator=generator)
_save_item(self._rank, i, item, self._save_dir, data_type)
if drop_samples:
del self._cached_samples
self._cached_samples = []
utils.free_memory()
self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
return iter(self._preprocessed_iterator)
@property
def requires_data(self):
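        """Whether a new consume/consume_once pass is needed before more items can be served."""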
if self._preprocessed_iterator is None:
return True
return self._preprocessed_iterator.requires_data
class PreprocessedDataIterable:
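    """Iterates once over the items a rank saved to disk for a given data type.

    Yielding the last item flips ``requires_data`` to True, signalling that the
    preprocessor must consume fresh data before iteration can resume.
    """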
def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
self._rank = rank
self._save_dir = pathlib.Path(save_dir)
self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
self._data_type = data_type
self._requires_data = False
def __iter__(self) -> Iterable[Dict[str, Any]]:
for i in range(self._num_items):
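            # Flag that fresh data will be needed once the final saved item is yielded.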
if i == self._num_items - 1:
self._requires_data = True
yield _load_item(self._rank, i, self._save_dir, self._data_type)
def __len__(self) -> int:
return self._num_items
@property
def requires_data(self):
return self._requires_data
class PreprocessedOnceDataIterable:
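    """Endlessly cycles over the items a rank saved to disk for a given data type.

    Because the same precomputed items are reused forever, ``requires_data``
    never becomes True.
    """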
def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
self._rank = rank
self._save_dir = pathlib.Path(save_dir)
self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
self._data_type = data_type
self._requires_data = False
def __iter__(self) -> Iterable[Dict[str, Any]]:
index = 0
while True:
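            # Wrap the index around so the saved items are replayed indefinitely.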
yield _load_item(self._rank, index, self._save_dir, self._data_type)
index = (index + 1) % self._num_items
def __len__(self) -> int:
return self._num_items
@property
def requires_data(self):
return self._requires_data
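# Items are stored as individual files named "{data_type}-{rank}-{index}.pt" inside save_dir.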
def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
filename = directory / f"{data_type}-{rank}-{index}.pt"
torch.save(item, filename.as_posix())
def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]:
filename = directory / f"{data_type}-{rank}-{index}.pt"
return torch.load(filename.as_posix(), weights_only=True)
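

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how the pieces above might fit together; the processor
# function, dataset iterator, and component names below are hypothetical.
#
#   def encode_text(caption, text_encoder, generator=None, **kwargs):
#       return {"prompt_embeds": text_encoder(caption)}
#
#   preprocessor = DistributedDataPreprocessor(
#       rank=0,
#       num_items=100,
#       processor_fn={"condition": encode_text},
#       save_dir="/tmp/precomputed",
#   )
#   condition_iterator = preprocessor.consume(
#       "condition",
#       components={"text_encoder": my_text_encoder},
#       data_iterator=iter(dataset),
#   )
#   for item in condition_iterator:
#       ...  # each item is loaded back from disk via torch.load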