from typing import Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._prims_common.wrappers import backwards_not_supported
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.types import _device, _dtype
rngprim_namespace = "rngprims"
rngprim = torch.library.Library(rngprim_namespace, "DEF")
rngprim_impl = torch.library.Library(
rngprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
rngprim_autograd_impl = torch.library.Library(rngprim_namespace, "IMPL", "Autograd")
rngprim_meta_impl = torch.library.Library(rngprim_namespace, "IMPL", "Meta")
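# The "DEF" fragment owns the schema definitions for the "rngprims" namespace;
# the "IMPL" fragments register kernels against specific dispatch keys
# (CompositeExplicitAutograd, Autograd, and Meta respectively).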
def throw_on_non_cuda(device):
    # Return the error instead of raising it so that call sites can write
    # ``raise throw_on_non_cuda(device)``, making the control flow explicit.
    return RuntimeError(
        f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
rngprim.define(schema)
rngprim_impl.impl(name, impl_aten)
rngprim_meta_impl.impl(name, impl_meta)
prim_packet = getattr(torch._ops.ops.rngprims, name)
prim = prim_packet.default
if tags:
prim._tags = tags
rngprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = doc
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
p.schema = schema
p.impl_aten = impl_aten
p.prim_meta_impl = impl_meta
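# Illustrative sketch: a prim registered through register_rng_prim surfaces as
# an overload packet under torch.ops.rngprims. For the "philox_rand" prim
# defined below, for example:
#
#   packet = torch.ops.rngprims.philox_rand  # OpOverloadPacket
#   prim = packet.default                    # OpOverload carrying the impls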
# Philox rand offsets could be shared with other Philox ops in the future, so
# these functions are kept at module scope.
def philox_rand_offset_meta(
shape: torch.Size,
):
return _prims.TensorLike(torch.tensor(0, dtype=torch.int64))
def philox_rand_offset(
shape: torch.Size,
):
    # For the implementation, see the function calc_execution_policy in
    # aten/src/ATen/native/cuda/DistributionTemplates.h. This impl was copied
    # from there at commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6.
numel_scalar = 1
for dim_size in shape:
numel_scalar *= dim_size
numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)
block_size = 256
unroll = 4
curand4_engine_calls = 4
device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
grid_size = (numel + block_size - 1) // block_size
grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
offset = (
(numel - 1) // (block_size * grid_size * unroll) + 1
) * curand4_engine_calls
return offset
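# Worked example of the offset computation above, on a hypothetical device
# with 80 SMs and 2048 max threads per SM, for numel = 1_000_000:
#   blocks_per_sm = 2048 // 256 = 8
#   grid_size = min(ceil(1_000_000 / 256), 80 * 8) = min(3907, 640) = 640
#   offset = ((1_000_000 - 1) // (256 * 640 * 4) + 1) * 4 = (1 + 1) * 4 = 8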
def register_philox_rand():
name = "philox_rand"
schema = "philox_rand(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" # noqa: B950
def _philox_rand_meta(
shape: torch.Size,
seed: torch.Tensor,
offset: torch.Tensor,
stride: Optional[Tuple[int, ...]],
device: _device,
dtype: _dtype,
):
        # The stride arg will be useful for the distributed use case; currently it is unused.
assert stride is None
stride = make_contiguous_strides_for(shape)
random_values = _prims.TensorMeta(
shape=shape, strides=stride, dtype=dtype, device=device
)
offset = philox_rand_offset_meta(shape)
return (random_values, offset)
def _philox_rand(
shape: torch.Size,
seed: torch.Tensor,
offset: torch.Tensor,
stride: Optional[Tuple[int, ...]],
device: _device,
dtype: _dtype,
):
        # The stride arg will be useful for the distributed use case; currently it is unused.
assert stride is None
if device.type == "cpu":
devices = []
else:
devices = [device]
if device.type != "cuda":
raise throw_on_non_cuda(device)
with torch.random.fork_rng(devices):
CUDARngStateHelper.set_torch_state_tensor(seed, offset)
random_values = torch.rand(shape, device=device, dtype=dtype)
return random_values, philox_rand_offset(shape)
register_rng_prim(
name=name,
schema=schema,
impl_aten=_philox_rand,
impl_meta=_philox_rand_meta,
doc="Philox based stateless rand operator",
tags=(torch.Tag.nondeterministic_seeded,),
)
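# Usage sketch for the prim registered above (illustrative; assumes a CUDA
# device, since the aten impl raises for anything else). The seed/offset
# tensors are assumed to come from the current CUDA RNG state via
# CUDARngStateHelper:
#
#   seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
#   values, consumed_offset = torch.ops.rngprims.philox_rand(
#       (4, 4), seed, offset, None, torch.device("cuda"), torch.float32
#   )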
def get_device(args, kwargs):
    device = kwargs.get("device")
    if device:
        if isinstance(device, str):
            device = torch.device(device)
        return device.type
devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
if any(dev == "cuda" for dev in devices):
return "cuda"
elif any(dev == "cpu" for dev in devices):
return "cpu"
return None
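# Illustrative behavior of get_device:
#   get_device((torch.randn(2, device="cuda"),), {})  -> "cuda"
#   get_device((), {"device": "cpu"})                 -> "cpu"
#   get_device((), {})                                -> None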
def register_run_and_save_rng_state_op():
run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
)
@run_and_save_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(op, *args, **kwargs):
return torch.cuda.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.CPU)
def impl_cpu(op, *args, **kwargs):
return torch.get_rng_state(), op(*args, **kwargs)
@run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(op, *args, **kwargs):
impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
device = get_device(args, kwargs)
assert device in impl_map, f"Backend not supported for {device}"
impl = impl_map[device]
return impl(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(mode, op, *args, **kwargs):
# Check device to call the right impl
with mode:
return impl_backend_select(op, *args, **kwargs)
@run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
if mode.enable_tracing:
out = impl_backend_select(op, *args, **kwargs)
proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
else:
return run_and_save_rng_state(op, *args, **kwargs)
return run_and_save_rng_state
def register_run_with_rng_state_op():
run_with_rng_state = HigherOrderOperator("run_with_rng_state")
run_with_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(run_with_rng_state, deferred_error=True)
)
@run_with_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(rng_state, op, *args, **kwargs):
current_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state.cpu())
out = op(*args, **kwargs)
torch.cuda.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(DispatchKey.CPU)
def impl_cpu(rng_state, op, *args, **kwargs):
current_state = torch.get_rng_state()
torch.set_rng_state(rng_state)
out = op(*args, **kwargs)
torch.set_rng_state(current_state)
return out
@run_with_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
if mode.enable_tracing:
with disable_proxy_modes_tracing():
out = run_with_rng_state(rng_state, op, *args, **kwargs)
proxy_args = pytree.tree_map(
mode.tracer.unwrap_proxy, (rng_state, op, *args)
)
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
out_proxy = mode.tracer.create_proxy(
"call_function", run_with_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
else:
return run_with_rng_state(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(rng_state, op, *args, **kwargs):
impl_map = {"cuda": impl_cuda, "cpu": impl_cpu}
device = get_device(args, kwargs)
assert device in impl_map, f"Backend not supported for {device}"
impl = impl_map[device]
return impl(rng_state, op, *args, **kwargs)
@run_with_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
        # Skip setting the RNG state here: it does not work well with fake
        # tensors, and it does not matter under fake tensor mode anyway.
with mode:
return op(*args, **kwargs)
return run_with_rng_state
run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
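# Sketch of how the two ops pair up (illustrative; run on CPU here): the state
# saved by run_and_save_rng_state lets run_with_rng_state replay an op later
# and reproduce the same random values, without disturbing the ambient RNG
# state at the replay point.
#
#   op = torch.ops.aten.rand.default
#   state, a = run_and_save_rng_state(op, [4], device=torch.device("cpu"))
#   b = run_with_rng_state(state, op, [4], device=torch.device("cpu"))
#   # a and b hold the same random values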
def register_rng_prims():
register_philox_rand()