|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
#include <ATen/ATen.h> |
|
#include <ATen/core/op_registration/op_registration.h> |
|
#include <torch/library.h> |
|
|
|
namespace at { |
|
namespace functorch { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); |
|
|
|
|
|
|
|
TORCH_API bool isVmapFallbackWarningEnabled(); |
|
TORCH_API void setVmapFallbackWarningEnabled(bool enabled); |
|
|
|
|
|
|
|
TORCH_API bool isVmapFallbackEnabled(); |
|
TORCH_API void setVmapFallbackEnabled(bool enabled); |
|
|
|
template <typename A> A vector_to_result(const std::vector<IValue>& buffer) { |
|
return buffer[0].to<A>(); |
|
} |
|
template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) { |
|
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>()); |
|
} |
|
template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) { |
|
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>()); |
|
} |
|
|
|
|
|
|
|
template <typename Ret> |
|
Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) { |
|
std::vector<IValue> stack(args.begin(), args.end()); |
|
batchedTensorForLoopFallback(op, &stack); |
|
return vector_to_result<Ret>(stack); |
|
} |
|
|
|
template <typename A, typename B> |
|
std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) { |
|
std::vector<IValue> stack(args.begin(), args.end()); |
|
batchedTensorForLoopFallback(op, &stack); |
|
return vector_to_result<A, B>(stack); |
|
} |
|
|
|
template <typename A, typename B, typename C> |
|
std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) { |
|
std::vector<IValue> stack(args.begin(), args.end()); |
|
batchedTensorForLoopFallback(op, &stack); |
|
return vector_to_result<A, B, C>(stack); |
|
} |
|
|
|
|
|
} |
|
} |
|
|