File size: 5,589 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# mypy: allow-untyped-defs
import inspect
import warnings
from typing import Any, List, Optional, Set
from typing_extensions import deprecated
import torch
from torch.utils.data.datapipes.iter.sharding import (
_ShardingIterDataPipe,
SHARDING_PRIORITIES,
)
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
__all__ = [
"apply_random_seed",
"apply_sharding",
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
]
def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
return _get_all_graph_pipes_helper(graph, set())
def _get_all_graph_pipes_helper(graph: DataPipeGraph, id_cache: Set[int]) -> List[DataPipe]:
results: List[DataPipe] = []
for dp_id, (datapipe, sub_graph) in graph.items():
if dp_id in id_cache:
continue
id_cache.add(dp_id)
results.append(datapipe)
results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
return results
def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
if isinstance(datapipe, _ShardingIterDataPipe):
return True
if hasattr(datapipe, "apply_sharding") and inspect.ismethod(datapipe.apply_sharding):
return True
return False
def apply_sharding(
    datapipe: DataPipe,
    num_of_instances: int,
    instance_id: int,
    sharding_group=SHARDING_PRIORITIES.DEFAULT,
) -> DataPipe:
    r"""
    Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``.

    RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch.

    Args:
        datapipe: Root of the pipeline to shard (returned unchanged).
        num_of_instances: Total number of shards.
        instance_id: Index of this shard (``0 <= instance_id < num_of_instances``).
        sharding_group: Priority group; forwarded only to pipes whose
            ``apply_sharding`` accepts a third parameter.
    """

    def _supports_sharding(dp: DataPipe) -> bool:
        # Either an explicit _ShardingIterDataPipe, or any pipe exposing a
        # bound ``apply_sharding`` method (duck typing).
        return isinstance(dp, _ShardingIterDataPipe) or (
            hasattr(dp, "apply_sharding") and inspect.ismethod(dp.apply_sharding)
        )

    def _do_apply(dp: DataPipe) -> None:
        # For BC, only provide sharding_group if accepted
        if len(inspect.signature(dp.apply_sharding).parameters) < 3:
            dp.apply_sharding(num_of_instances, instance_id)
        else:
            dp.apply_sharding(
                num_of_instances, instance_id, sharding_group=sharding_group
            )

    def _walk(sub_graph, prev_applied=None):
        # Depth-first walk; ``prev_applied`` tracks a sharder already seen on
        # the current branch so a second one can be rejected.
        for dp, children in sub_graph.values():
            applied_here = prev_applied
            if _supports_sharding(dp):
                if prev_applied is not None:
                    raise RuntimeError(
                        "Sharding twice on a single pipeline is likely unintended and will cause data loss. "
                        f"Sharding already applied to {prev_applied} while trying to apply to {dp}"
                    )
                _do_apply(dp)
                applied_here = dp
            _walk(children, applied_here)

    _walk(traverse_dps(datapipe))
    return datapipe
def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
if not hasattr(datapipe, "set_shuffle") or not hasattr(datapipe, "set_seed"):
return False
if not inspect.ismethod(datapipe.set_shuffle) or not inspect.ismethod(datapipe.set_seed):
return False
return True
def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find and set shuffle attribute.

    Apply the method to each `DataPipe` that has APIs of ``set_shuffle``
    and ``set_seed``.

    Args:
        datapipe: DataPipe that needs to set shuffle attribute
        shuffle: Shuffle option (default: ``None`` and no-op to the graph)
    """
    # ``None`` means "leave the graph as-is".
    if shuffle is None:
        return datapipe

    shufflers = [
        pipe
        for pipe in get_all_graph_pipes(traverse_dps(datapipe))
        if _is_shuffle_datapipe(pipe)
    ]
    if shuffle and not shufflers:
        # No Shuffler anywhere in the graph: append one at the end so the
        # requested shuffling still happens.
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        datapipe = datapipe.shuffle()
        shufflers = [datapipe]  # type: ignore[list-item]

    for shuffler in shufflers:
        shuffler.set_shuffle(shuffle)

    return datapipe
@deprecated(
    "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. "
    "Please use `apply_random_seed` instead.",
    category=FutureWarning,
)
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    # Deprecated backward-compatibility alias: delegates directly to
    # `apply_random_seed` (a FutureWarning is emitted by the decorator).
    return apply_random_seed(datapipe, rng)
def _is_random_datapipe(datapipe: DataPipe) -> bool:
if hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed):
return True
return False
def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of ``set_seed``.

    Then set the random seed based on the provided RNG to those ``DataPipe``.

    Args:
        datapipe: DataPipe that needs to set randomness
        rng: Random number generator to generate random seeds
    """
    graph = traverse_dps(datapipe)

    # Dedupe by `id` (DataPipes may be unhashable) so each pipe is seeded at
    # most once even when it appears on several branches of the graph.
    seen_ids: Set[int] = set()
    random_datapipes: List[DataPipe] = []
    for pipe in get_all_graph_pipes(graph):
        if id(pipe) not in seen_ids and _is_random_datapipe(pipe):
            random_datapipes.append(pipe)
        seen_ids.add(id(pipe))

    # Seeds are drawn from `rng` in graph order, which keeps the assignment
    # reproducible for a fixed generator state.
    for pipe in random_datapipes:
        random_seed = int(
            torch.empty((), dtype=torch.int64).random_(generator=rng).item()
        )
        pipe.set_seed(random_seed)

    return datapipe
|