namespace at { | |
namespace functorch { | |
// NOTE [functorch TLS in pytorch/pytorch] | |
// | |
// functorch lives out-of-tree. However, it has some TLS that needs to be | |
// propagated. The solution for that is we store a pointer to the TLS | |
// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to | |
// include whatever functorch needs. | |
// | |
// We need to store a pointer due to the indirection: | |
// inside functorch, we will create a subclass of FunctorchTLSBase called | |
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack. | |
// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined | |
// yet. | |
// | |
// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside | |
// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*. | |
// We can't directly pass around FunctorchTLSBase (without a pointer) because | |
// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having | |
// more elements. | |
struct TORCH_API FuncTorchTLSBase { | |
virtual ~FuncTorchTLSBase() = default; | |
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0; | |
// functorch doesn't always work with autograd.Function. | |
// This is a hook to get into functorch -- functorch will determine | |
// if it should raise an error message | |
virtual int64_t checkSupportsAutogradFunction() const = 0; | |
virtual void checkSupportsInplaceRequiresGrad() const = 0; | |
virtual void checkSupportsRetainGrad() const = 0; | |
}; | |
// returns deepcopy of the functorch tls | |
TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS(); | |
// sets the functorch tls. always does a deep copy. | |
TORCH_API void setFuncTorchTLS( | |
const std::shared_ptr<const FuncTorchTLSBase>& state); | |
// get a mutable reference to the functorch tls | |
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor(); | |
} // namespace functorch | |
} // namespace at | |