|
#pragma once |
|
|
|
#include <ATen/functorch/Macros.h> |
|
#include <ATen/core/dispatch/Dispatcher.h> |
|
#include <c10/core/impl/LocalDispatchKeySet.h> |
|
#include <c10/util/Optional.h> |
|
#include <c10/util/variant.h> |
|
#include <bitset> |
|
|
|
namespace at { namespace functorch { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// How random operations are treated under a vmap layer; this enum is stored in
// VmapInterpreterMeta. The exact semantics of each value live in the
// implementation (not visible here) — the notes below are inferred from the
// names and should be confirmed.
// Kept as an unscoped enum: callers may use the enumerators unqualified.
enum RandomnessType {
  Error = 0,      // presumably: randomness inside the transform is rejected
  Same = 1,       // presumably: every batch element sees the same randomness
  Different = 2,  // presumably: each batch element gets independent randomness
  END = 3         // past-the-end marker (one greater than the last real value)
};
|
|
|
// Identifies which functorch transform an Interpreter layer represents.
// Enumerator values are spelled out; they match the implicit 0..4 numbering.
enum class TransformType {
  Torch = 0,
  Vmap = 1,
  Grad = 2,
  Jvp = 3,
  Functionalize = 4,
};

// Streams a human-readable form of `t`; defined out of line.
std::ostream& operator<<(std::ostream& os, const TransformType& t);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct VmapInterpreterMeta { |
|
explicit VmapInterpreterMeta(int64_t batchSize, RandomnessType randomness) : |
|
batchSize_(batchSize), randomness_(randomness) {} |
|
int64_t batchSize_; |
|
RandomnessType randomness_; |
|
}; |
|
|
|
// Per-layer metadata carried by a grad Interpreter.
struct GradInterpreterMeta {
  // Grad-mode flag recorded when this layer was created.
  bool prevGradMode_;

  explicit GradInterpreterMeta(bool prevGradMode) : prevGradMode_{prevGradMode} {}
};
|
|
|
// Per-layer metadata carried by a jvp Interpreter.
struct JvpInterpreterMeta {
  // Forward-grad-mode flag recorded when this layer was created.
  bool prevFwdGradMode_;

  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_{prevFwdGradMode} {}
};
|
|
|
// Per-layer metadata carried by a functionalize Interpreter.
struct FunctionalizeInterpreterMeta {
  // Whether this functionalize layer was configured to add back views.
  bool functionalizeAddBackViews_;

  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews)
      : functionalizeAddBackViews_{functionalizeAddBackViews} {}
};
|
|
|
// Transform-specific metadata held by an Interpreter. The alternative order is
// significant (it determines variant index()) and must not be changed. The
// leading int64_t alternative is what a default-constructed variant holds.
using InterpreterMeta = c10::variant<
    int64_t,
    GradInterpreterMeta,
    JvpInterpreterMeta,
    VmapInterpreterMeta,
    FunctionalizeInterpreterMeta>;
|
|
|
|
|
struct Interpreter { |
|
|
|
static Interpreter Vmap(int64_t level, int64_t batchSize, RandomnessType randomness) { |
|
return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(batchSize, randomness)); |
|
} |
|
static Interpreter Grad(int64_t level, bool prevGradMode) { |
|
return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode)); |
|
} |
|
static Interpreter Jvp(int64_t level, bool prevFwdGradMode) { |
|
return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode)); |
|
} |
|
static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) { |
|
return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews)); |
|
} |
|
|
|
|
|
TransformType key() const { return type_; } |
|
int64_t level() const { return level_; } |
|
const InterpreterMeta& meta() const { return meta_; } |
|
|
|
void process(const c10::OperatorHandle& op, torch::jit::Stack* stack); |
|
void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); |
|
|
|
void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) { |
|
TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value()); |
|
savedLocalDispatchKeySet_ = std::move(keyset); |
|
} |
|
void clearSavedLocalDispatchKeySet() { |
|
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); |
|
savedLocalDispatchKeySet_ = c10::nullopt; |
|
} |
|
c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const { |
|
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value()); |
|
return *savedLocalDispatchKeySet_; |
|
} |
|
|
|
|
|
explicit Interpreter() = default; |
|
|
|
private: |
|
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta): |
|
type_(type), level_(level), meta_(meta) {} |
|
|
|
|
|
TransformType type_; |
|
int64_t level_; |
|
optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_; |
|
InterpreterMeta meta_; |
|
}; |
|
|
|
|
|
|
|
|
|
// Declared here, defined out of line (not visible in this header).
// Presumably applies `func` to each Tensor found in args[begin, end) and stores
// the result back into `args` in place (per the name). Confirm against the
// definition, including how non-Tensor IValues in the range are handled.
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func);
|
|
|
|
|
|
|
|
|
|
|
|
|
// Declared here, defined out of line (not visible in this header).
// A variant of foreachTensorInplace whose `func` also receives a bool.
// Presumably that bool is bit (i - begin) of `use_flag_relative`, i.e. the
// bitset is indexed relative to `begin` (per the parameter name) — confirm.
// Since the flags are a std::bitset<64>, at most 64 positions can be described.
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
    const std::bitset<64> use_flag_relative, std::function<Tensor(const Tensor&, bool)> func);
|
|
|
// The declarations below are defined out of line. Their descriptions are
// inferred from names and signatures only — confirm against the implementations.

// Presumably returns the indices within args[begin, end) whose values are not
// wrapped by a functorch transform.
std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);

// Returns a DispatchKeySet for the given transform `key`; per the name, the keys
// to exclude when entering a dynamic layer of that transform.
DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);

// Presumably installs `exclude`/`include` into thread-local dispatch key state.
void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include);

// Presumably validates `stack` against `op`; confirm exactly what is checked.
void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
|
|
}} |
|
|