|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
#include <ATen/functorch/Macros.h> |
|
#include <c10/core/DispatchKey.h> |
|
#include <ATen/core/function_schema.h> |
|
#include <c10/util/Optional.h> |
|
#include <c10/util/variant.h> |
|
#include <unordered_map> |
|
#include <mutex> |
|
#include <c10/core/impl/LocalDispatchKeySet.h> |
|
#include <ATen/functorch/Interpreter.h> |
|
#include <ATen/functorch/VmapInterpreter.h> |
|
#include <ATen/functorch/ADInterpreters.h> |
|
#include <ATen/functorch/FunctionalizeInterpreter.h> |
|
|
|
|
|
// Forward declaration of c10::AutogradMetaInterface, so the type can be named |
// without pulling in its full definition. |
// NOTE(review): no declaration visible in this header refers to it -- confirm |
// it is still needed before removing. |
namespace c10 { struct AutogradMetaInterface; } |
|
|
|
namespace at { |
|
namespace functorch { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Element type of the dynamic layer stack (see getDynamicLayerStack). The only |
// state declared here is a single Interpreter, exposed through interpreter(). |
// |
// Construction requires a TransformType and a layer id; every other argument |
// is optional and defaults to nullopt. |
// NOTE(review): which optional argument each transform type actually consumes |
// is decided in the out-of-line constructor definition -- confirm there. |
struct TORCH_API DynamicLayer { |
|
explicit DynamicLayer( |
|
TransformType transform_type, |
|
int64_t layerId, |
|
optional<int64_t> batchSize = nullopt, |
|
optional<RandomnessType> randomness = nullopt, |
|
optional<bool> prev_grad_mode = nullopt, |
|
// Spelled `prev_` to agree with prev_grad_mode above and with the matching |
// parameter of initAndPushDynamicLayer. |
optional<bool> prev_fwd_grad_mode = nullopt, |
|
optional<bool> functionalize_add_back_views = nullopt); |
|
|
|
// Returns a TransformType. NOTE(review): presumably the transform_type this |
// layer was constructed with -- defined out of line; confirm. |
TransformType key() const; |
|
// Returns an int64_t. NOTE(review): presumably the layerId passed to the |
// constructor -- defined out of line; confirm. |
int64_t layerId() const; |
|
|
|
// Read-only and mutable access to the Interpreter owned by this layer. |
const Interpreter& interpreter() const { return interpreter_; } |
|
Interpreter& interpreter() { return interpreter_; } |
|
|
|
|
|
// NOTE(review): both accessors return non-optional values although the |
// corresponding constructor arguments are optional; presumably they are only |
// valid for layers that were given those values -- confirm in the definition. |
int64_t batchSize() const; |
|
RandomnessType randomness() const; |
|
|
|
private: |
|
Interpreter interpreter_; |
|
}; |
|
|
|
// Takes a TransformType plus optional settings (all defaulting to nullopt) and |
// returns an int64_t. The optional parameters have the same types and order as |
// the optional parameters of DynamicLayer's constructor. |
// NOTE(review): by its name this constructs a DynamicLayer, pushes it onto the |
// dynamic layer stack, and presumably returns the new layer's id -- defined |
// out of line; confirm in the definition. |
TORCH_API int64_t initAndPushDynamicLayer( |
|
TransformType transform_type, |
|
optional<int64_t> batch_size = nullopt, |
|
optional<RandomnessType> randomness = nullopt, |
|
optional<bool> prev_grad_mode = nullopt, |
|
optional<bool> prev_fwd_grad_mode = nullopt, |
|
optional<bool> functionalize_add_back_views = nullopt); |
|
// Returns a DynamicLayer by value. |
// NOTE(review): by its name this removes the most recently pushed layer from |
// the dynamic layer stack and discards metadata tied to it; behavior on an |
// empty stack is not visible here -- confirm in the definition. |
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata(); |
|
// Returns an optional DynamicLayer (a copy, not a reference). |
// NOTE(review): presumably the top of the dynamic layer stack, or an empty |
// optional when the stack is empty -- confirm in the definition. |
TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer(); |
|
// Read access to the dynamic layer stack as a const reference to a vector of |
// DynamicLayer. NOTE(review): the lifetime/ownership of the referenced vector |
// (e.g. whether it is thread-local) is not visible here -- confirm. |
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack(); |
|
// Takes a vector of DynamicLayer by const reference. |
// NOTE(review): by its name this replaces the current dynamic layer stack with |
// `stack` -- confirm in the definition. |
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack); |
|
// Takes a single bool flag and returns nothing. |
// NOTE(review): by its name this toggles whether the DynamicLayer front/back |
// dispatch keys are part of the included dispatch key set -- confirm in the |
// definition. |
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included); |
|
|
|
|
|
|
|
// Returns a bool. |
// NOTE(review): presumably true when at least one layer is on the dynamic |
// layer stack -- confirm in the definition. |
TORCH_API bool areTransformsActive(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns a shared_ptr<bool> for the given level. |
// NOTE(review): presumably the pointed-to bool reports whether the layer at |
// `level` is still alive, and is shared with whoever holds the handle; the |
// behavior for a level that is not on the stack is not visible here -- confirm. |
TORCH_API std::shared_ptr<bool> getLifeHandleForLevel(int64_t level); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Classifies an operator from its schema alone; returns a bool. |
// NOTE(review): the exact criteria for "in-place" (e.g. which arguments must |
// be written to or aliased by outputs) live in the definition -- confirm there. |
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema); |
|
|
|
|
|
// Takes a function schema and an input index; returns an optional size_t. |
// NOTE(review): by its name the result is the index of the output that aliases |
// input `immutable_input`, or an empty optional when there is none -- confirm |
// in the definition. |
// The parameter is declared without top-level `const` (it has no effect in a |
// declaration), and the schema type is spelled c10::FunctionSchema to match |
// isInplaceOp. |
TORCH_API c10::optional<size_t> findAliasedOutput(const c10::FunctionSchema& schema, int64_t immutable_input); |
|
|
|
// Takes a Tensor by const reference and returns a Tensor by value. |
// NOTE(review): by its name this strips a functorch wrapper whose layer is no |
// longer alive and otherwise returns the input unchanged -- confirm in the |
// definition. |
TORCH_API Tensor unwrapIfDead(const Tensor& tensor); |
|
|
|
|
|
// Stream-insertion overloads for a single DynamicLayer and for a whole vector |
// of them. The output format is produced by the out-of-line definitions. |
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); |
|
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack); |
|
|
|
|
|
|
|
// Setter/getter pair for a boolean "inplace requires_grad allowed" flag. |
// NOTE(review): where the flag is stored (global vs. thread-local) and which |
// code consults it are not visible in this header -- confirm in the definition. |
TORCH_API void setInplaceRequiresGradAllowed(bool allowed); |
|
TORCH_API bool getInplaceRequiresGradAllowed(); |
|
|
|
} |
|
} |
|
|