#pragma once

#include <ATen/core/Generator.h>
#include <ATen/cuda/detail/PhiloxCudaStateRaw.cuh>
#include <ATen/Context.h>
#include <limits>

namespace at {
/**
 * Note [CUDA Graph-safe RNG states]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Strategy:
 * ~~~~~~~~~
 * (It helps to look at
 * cuda/detail/PhiloxCudaStateRaw.cuh and
 * cuda/detail/UnpackRaw.cuh
 * while you read this.)
 *
 * A CUDA graph containing multiple RNG ops behaves like a
 * single giant kernel from the perspective of ops external
 * to the graph. During graph capture, logic below records
 * the total of all offset increments that occur in the graphed
 * region, and records the final total as the offset for the
 * entire graph.
 *
 * When the graph reruns, the logic that reruns it
 * increments this device's CUDA generator's offset
 * by that total.
 *
 * Meanwhile, within the graph, at capture time, instead of
 * populating PhiloxCudaStates with the uint64_t offset pulled
 * directly from the global state, PhiloxCudaState instead
 * holds a pointer to a one-element stream-local int64_t device tensor
 * holding an initial offset value, and a uint64_t holding an
 * intra-graph offset. (The intra-graph offset starts from zero
 * when capture begins.) In each consumer kernel,
 * at::cuda::philox::unpack computes the offset to use for this kernel
 * as intra-graph offset + *initial offset.
 *
 * When the graph reruns, the logic that reruns it first
 * fill_s the initial offset tensor with this device's
 * CUDA generator's current offset.
 *
 * The control flow above ensures graphed execution is bitwise
 * identical to eager execution as long as RNG ops are enqueued
 * from a single thread, even if RNG ops and graphs containing
 * RNG ops are enqueued and run simultaneously on multiple streams.
 *
 * Usage:
 * ~~~~~~
 * PhiloxCudaState in this file, and unpack() in
 * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
 * CUDAGeneratorImpl whether graph capture is underway or not.
 *
 * Each PhiloxCudaState instance should be used for one and only one
 * consumer kernel.
 *
 * Example (see e.g. native/cuda/Dropout.cu):
 *
 * #include <ATen/cuda/CUDAGeneratorImpl.h>
 * #include <ATen/cuda/CUDAGraphsUtils.cuh>
 *
 * __global__ void kernel(..., PhiloxCudaState philox_args) {
 *   auto seeds = at::cuda::philox::unpack(philox_args);
 *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 *   curandStatePhilox4_32_10_t state;
 *   curand_init(std::get<0>(seeds), // seed
 *               idx,                // per-thread subsequence
 *               std::get<1>(seeds), // offset in subsequence
 *               &state);
 *   ...
 * }
 *
 * host_caller(...) {
 *   PhiloxCudaState rng_engine_inputs;
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen->mutex_);
 *
 *     // gen could be HostState or DevState here! No divergent code needed!
 *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
 *   }
 *   kernel<<<...>>>(..., rng_engine_inputs);
 * }
 */
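/*
 * Illustrative sketch of the offset arithmetic described in the Strategy
 * note above. This is not the authoritative implementation: see
 * cuda/detail/UnpackRaw.cuh for the real unpack(), and
 * cuda/detail/PhiloxCudaStateRaw.cuh for PhiloxCudaState's members,
 * whose names below (captured_, seed_, offset_, offset_intragraph_) are
 * assumed for the sake of the example:
 *
 * __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
 * unpack(at::PhiloxCudaState arg) {
 *   if (arg.captured_) {
 *     // Capture/replay: the initial offset lives in a one-element
 *     // device tensor, refreshed by the graph's rerun logic, and each
 *     // consumer kernel adds its own intra-graph offset on top of it.
 *     return std::make_tuple(
 *         arg.seed_,
 *         static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
 *   } else {
 *     // Eager: the offset was baked in on the host as a plain value.
 *     return std::make_tuple(arg.seed_, arg.offset_.val);
 *   }
 * }
 */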
struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  CUDAGeneratorImpl(DeviceIndex device_index = -1);
  ~CUDAGeneratorImpl() override = default;

  // CUDAGeneratorImpl methods
  std::shared_ptr<CUDAGeneratorImpl> clone() const;
  void set_current_seed(uint64_t seed) override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  void set_philox_offset_per_thread(uint64_t offset);
  uint64_t philox_offset_per_thread() const;
  void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
  uint64_t capture_epilogue();
  PhiloxCudaState philox_cuda_state(uint64_t increment);

  // Temporarily accommodates call sites that use philox_engine_inputs.
  // Allows incremental refactor of call sites to use philox_cuda_state.
  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);

  static DeviceType device_type();

private:
  CUDAGeneratorImpl* clone_impl() const override;
  uint64_t seed_ = default_rng_seed_val;
  uint64_t philox_offset_per_thread_ = 0;
  int64_t* seed_extragraph_{};
  int64_t* offset_extragraph_{};
  uint32_t offset_intragraph_ = 0;
  bool graph_expects_this_gen_ = false;
};

namespace cuda {
namespace detail {

TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
    DeviceIndex device_index = -1);
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace detail
} // namespace cuda
} // namespace at
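/*
 * Illustrative sketch of the capture/replay handshake described in the
 * Strategy note above, from the perspective of a hypothetical graph
 * runner. The real logic lives in the CUDA graph capture/replay code;
 * the variable names here are assumed for the example:
 *
 * // Before capture: hand the generator one-element device tensors that
 * // the graphed kernels will read their seed and initial offset from.
 * auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
 * at::Tensor seed_extragraph = at::empty({1}, options);
 * at::Tensor offset_extragraph = at::empty({1}, options);
 * gen->capture_prologue(seed_extragraph.data_ptr<int64_t>(),
 *                       offset_extragraph.data_ptr<int64_t>());
 *
 * // ... capture the graphed region ...
 *
 * // After capture: the generator reports the total offset increment
 * // consumed by all RNG ops in the graphed region.
 * uint64_t wholegraph_increment = gen->capture_epilogue();
 *
 * // On each replay: first fill_ the extragraph tensors from the
 * // generator's current state, then advance the generator past the
 * // whole graph so eager ops enqueued afterward see fresh offsets.
 * seed_extragraph.fill_(static_cast<int64_t>(gen->current_seed()));
 * offset_extragraph.fill_(static_cast<int64_t>(gen->philox_offset_per_thread()));
 * gen->philox_cuda_state(wholegraph_increment);
 * // ... launch the captured graph ...
 */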