#pragma once #include #include #include #include #include #include #include #include #include #include #include namespace at { namespace native { namespace templates { // ==================================================== Random ======================================================== // The purpose of `update_from` and `update_to` is to find the closest valid int64_t numbers that can be used as the actual `from` and `to`. // The current implementation of `random_` uses uint64_t arithmetic and casts the result to the target dtype (scalar_t). // This casting can produce numbers that are greater than or equal to the `to` value. For instance: // // auto actual = torch::empty({3, 3}, torch::half); // actual.random_(0, 65504); // // If random's uint64_t arithmetic produces 65503 as a random value, it becomes 65504 after casting to torch::half, // which violates the requirement that the random value must be less than `to`. To resolve this issue, `update_from` and `update_to` // move `from` to the right and `to` to the left, to the nearest values that won't fall outside [from, to) after casting to // the target dtype. For `to` = 65504, `to` moves left by (1 << (floor(log2(to - 1)) - 11 + 1)) = 32 (11 being the number of significand digits of torch::half) and becomes 65472, which is the previous // representable number for the torch::half dtype. 
template int64_t update_from(int64_t from) { static_assert( std::is_floating_point::value || std::is_same::value || std::is_same::value, "scalar_t must be floating-point type"); const auto from_plus_1 = static_cast(static_cast(from + 1)); if (from_plus_1 < from) { int64_t from_ = std::abs(from + 1); int n = 0; while (from_ >>= 1) ++n; // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) from = from_plus_1 + (1LL << (n - std::numeric_limits::digits + 1)); } return from; } template int64_t update_to(int64_t to) { static_assert( std::is_floating_point::value || std::is_same::value || std::is_same::value, "scalar_t must be floating-point type"); const auto to_minus_1 = static_cast(static_cast(to - 1)); if (to_minus_1 >= to) { int64_t to_ = std::abs(to - 1); int n = 0; while (to_ >>= 1) ++n; // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) to = to_minus_1 - (1LL << (n - std::numeric_limits::digits + 1)); } return to; } template class random_kernel, typename RNG> at::Tensor& random_impl(at::Tensor& self, c10::optional generator) { auto iter = at::TensorIterator::borrowing_nullary_op(self); random_kernel()(iter, generator); return self; } #define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \ TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \ #define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \ if (var < -(1LL << digits) || var > (1LL << digits)) { \ TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \ "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. 
", \ "This warning will become an error in version 1.7 release, please fix the code in advance"); \ } static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) { const auto scalar_type = typeMetaToScalarType(dtype); if (isFloatingType(scalar_type)) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] { const auto min = static_cast(std::numeric_limits::lowest()); const auto max = static_cast(std::numeric_limits::max()); CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype); constexpr auto digits = std::numeric_limits::digits; WARN_OUT_OF_BOUNDS(from, "from", digits, dtype); WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype); }); } else if (isIntegralType(scalar_type, /*includeBool=*/true)) { AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, scalar_type, "check_random_integral_bounds", [&]() { const auto min = static_cast(std::numeric_limits::lowest()); const auto max = static_cast(std::numeric_limits::max()); CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype); }); } else { TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types"); } } template class random_from_to_kernel, typename RNG> at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional to_opt, c10::optional generator) { uint64_t range = 0; auto iter = at::TensorIterator::borrowing_nullary_op(self); if (to_opt.has_value()) { // [from, to) int64_t to = *to_opt; TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to); if (isFloatingType(iter.dtype())) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] { from = update_from(from); to = update_to(to); TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to 
be less than 'to' casted to dtype, but got from=", from, " >= to=", to); }); } check_from_to_in_range(from, to - 1, self.dtype()); range = static_cast(to) - static_cast(from); random_from_to_kernel()(iter, range, from, generator); } else if (from != std::numeric_limits::lowest()) { // [from, std::numeric_limits::max()] int64_t to_inc = 0; if (isFloatingType(iter.dtype())) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] { constexpr int64_t scalar_t_max = static_cast(1) << std::numeric_limits::digits; to_inc = scalar_t_max > std::numeric_limits::max() ? std::numeric_limits::max() : static_cast(scalar_t_max); from = update_from(from); TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc); }); } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) { AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "random_from_to_range_calc", [&] { if (std::is_same::value) { to_inc = static_cast(true); } else { to_inc = static_cast(std::numeric_limits::max()); } }); } else { TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types"); } check_from_to_in_range(from, to_inc, self.dtype()); range = static_cast(to_inc) - static_cast(from) + 1; random_from_to_kernel()(iter, range, from, generator); } else { // [std::numeric_limits::lowest(), std::numeric_limits::max()] // range = 2^64 random_from_to_kernel()(iter, generator); } return self; } // ==================================================== Normal ======================================================== #define CHECK_NORMAL_TENSOR_STD(std) \ do { \ TORCH_CHECK( \ !std.is_complex(), \ "normal expects standard deviation to be non-complex"); \ TORCH_CHECK( \ std.numel() == 0 || std.is_meta() || std.min().ge(0).item(), \ "normal expects all elements of std >= 0.0"); \ } 
while (0) #define CHECK_NORMAL_STD(std) \ TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std); template class normal_kernel, typename RNG> Tensor& normal_impl_(Tensor& self, double mean, double std, c10::optional gen) { CHECK_NORMAL_STD(std); if (self.is_complex()) { auto float_tensor = at::view_as_real(self); // variance for normal distribution of the real and imaginary values // is half of the input variance normal_kernel()(float_tensor, mean, std/(std::sqrt(2)), gen); } else { normal_kernel()(self, mean, std, gen); } return self; } template class normal_kernel, typename RNG> Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, c10::optional gen) { CHECK_NORMAL_STD(std); auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous); auto shape = at::infer_size(mean.sizes(), std_tensor.sizes()); at::native::resize_output(output, shape); normal_impl_(output, 0, std, gen); output.add_(mean); return output; } template class normal_kernel, typename RNG> Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, c10::optional gen) { CHECK_NORMAL_TENSOR_STD(std); auto mean_tensor = at::full({}, mean, output.options()); auto shape = at::infer_size(mean_tensor.sizes(), std.sizes()); at::native::resize_output(output, shape); normal_impl_(output, 0, 1, gen); // CUDA NB: addcmul_out copies the tensor to be added into the output. // The previous function here was addcmul_out(output, mean_tensor, output, std, 1); // The third argument is not a constant reference and hence the samples in output are overwritten. 
// Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std output.mul_(std).add_(mean_tensor); return output; } template class normal_kernel, typename RNG> Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional gen) { CHECK_NORMAL_TENSOR_STD(std); auto shape = at::infer_size(mean.sizes(), std.sizes()); at::native::resize_output(output, shape); normal_impl_(output, 0, 1, gen); // CUDA NB: addcmul_out copies the tensor to be added into the output. // The previous function here was addcmul_out(output, mean, output, std, 1); // The third argument is not a constant reference and hence the samples in output are overwritten. // Consequently, the computation performed is mean + mean * std instead of mean + output * std output.mul_(std).add_(mean); return output; } template class normal_kernel, typename RNG> Tensor normal_impl(const Tensor& mean, double std, c10::optional gen) { CHECK_NORMAL_STD(std); Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous); normal_out_impl(ret, mean, std, gen); return ret; } template class normal_kernel, typename RNG> Tensor normal_impl(double mean, const Tensor& std, c10::optional gen) { CHECK_NORMAL_TENSOR_STD(std); Tensor ret = at::empty_like(std, MemoryFormat::Contiguous); normal_out_impl(ret, mean, std, gen); return ret; } template class normal_kernel, typename RNG> Tensor normal_impl(const Tensor& mean, const Tensor& std, c10::optional gen) { CHECK_NORMAL_TENSOR_STD(std); auto shape = at::infer_size(mean.sizes(), std.sizes()); Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous); normal_out_impl(ret, mean, std, gen); return ret; } // ==================================================== Uniform ======================================================= template class uniform_kernel, typename RNG> at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, c10::optional generator) { if (self.is_complex()) { auto 
float_tensor = at::view_as_real(self); uniform_impl_(float_tensor, from, to, generator); } else { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] { const auto dtype = self.dtype(); const auto min = static_cast(std::numeric_limits::lowest()); const auto max = static_cast(std::numeric_limits::max()); CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype); TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to); TORCH_CHECK((to - from) <= std::numeric_limits::max(), "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()), ">::max(), but found to=", to, " and from=", from, " which result in to-from to exceed the limit"); from = std::min(std::max(from, min), max); to = std::max(std::min(to, max), min); }); auto iter = at::TensorIterator::borrowing_nullary_op(self); uniform_kernel()(iter, from, to, generator); } return self; } // ================================================== LogNormal ======================================================= template class log_normal_kernel, typename RNG> at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, c10::optional gen) { TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std); auto iter = TensorIterator::borrowing_nullary_op(self); log_normal_kernel()(iter, mean, std, gen); return self; } // =================================================== Geometric ====================================================== template class geometric_kernel, typename RNG> Tensor& geometric_impl_(Tensor& self, double p, c10::optional gen) { TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p); auto iter = TensorIterator::borrowing_nullary_op(self); geometric_kernel()(iter, p, gen); return self; } // ================================================== Exponential 
===================================================== template class exponential_kernel, typename RNG> Tensor& exponential_impl_(Tensor& self, double lambda, c10::optional gen) { TORCH_CHECK(lambda >= 0.0, "exponential_ expects lambda >= 0.0, but found lambda=", lambda); auto iter = TensorIterator::borrowing_nullary_op(self); exponential_kernel()(iter, lambda, gen); return self; } // ==================================================== Cauchy ======================================================== template class cauchy_kernel, typename RNG> Tensor& cauchy_impl_(Tensor& self, double median, double sigma, c10::optional gen) { auto iter = TensorIterator::borrowing_nullary_op(self); cauchy_kernel()(iter, median, sigma, gen); return self; } // ==================================================== Bernoulli ===================================================== template class bernoulli_tensor_kernel, typename RNG> Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, c10::optional gen) { NoNamesGuard guard; at::assert_no_internal_overlap(self); bernoulli_tensor_kernel()(self, p_, gen); return self; } template class bernoulli_scalar_kernel, typename RNG> Tensor& bernoulli_impl_(Tensor& self, double p, c10::optional gen) { TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p); at::assert_no_internal_overlap(self); bernoulli_scalar_kernel()(self, p, gen); return self; } template class bernoulli_tensor_kernel, typename RNG> Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, c10::optional gen) { // result.resize_as_(self) requires self to have same dtype as result, so we // use resize_ instead. // TODO: Fix resize_as_. See pytorch/pytorch#11665. result.resize_(self.sizes()); bernoulli_impl_(result, self, gen); namedinference::propagate_names(result, self); return result; } #undef CHECK_OUT_OF_BOUNDS #undef WARN_OUT_OF_BOUNDS }}}