namespace at { | |
namespace functionalization { | |
// See Note [Functionalization Pass In Core] | |
// ViewMeta is a class used by the functionalization pass to navigate between | |
// a base tensor and a view tensor. | |
// For example, if I call `b = a.view1(...)` | |
// the functionalization pass will generate and store a ViewMeta on b that looks | |
// like: | |
// | |
// ViewMeta( | |
// [<captures>](const Tensor& base, int64_t mutated_view_idx) { | |
// return base.view1(...); | |
// }, | |
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view, | |
// int64_t mutated_view_idx) -> at::Tensor { | |
// return at::functionalization::impl::view1_inverse(base, mutated_view, | |
// ...); | |
// } | |
// | |
// The forward_fn lambda describes how to replay view1 on a tensor. | |
// | |
// The reverse_fn lambda describes how, given a tensor that is already a view, | |
// how to get the corresponding base tensor. See Note [Functionalization Pass: | |
// View Inverses] for details. | |
struct ViewMeta { | |
ViewMeta( | |
std::function<Tensor(const Tensor&, int64_t)> forward, | |
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse, | |
int64_t out_idx = 0) | |
: forward_fn(forward), reverse_fn(reverse), out_index(out_idx) {} | |
std::function<Tensor(const Tensor&, int64_t)> forward_fn; | |
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn; | |
// See Note [out_idx in ViewMeta] | |
int64_t out_index; | |
// Returns a copy of the current ViewMeta, if out_idx matches the current | |
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse | |
// functions, but a new out index. | |
ViewMeta to_out_idx(int64_t out_idx); | |
}; | |
// Alias represents the state shared by (potentially multiple) views of the same | |
// tensor. For example, in the following code: | |
// | |
// b = a.view1(...) | |
// c = b.view2(...) | |
// b.add_(1) | |
// --> alias.add_update(b, {view1_meta}) | |
// | |
// The call to add_(1) will result in a call to alias.add_update(b, | |
// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose | |
// c is used in an expression (e.g. you try to print c, or pass it to an | |
// operator). Doing so will involve "syncing" c. First we apply any pending | |
// updates to the alias, and then we regenerate c by replaying its views off of | |
// the updated alias. E.g: | |
// | |
// print(str(c)) | |
// --> c.sync_() | |
// --> alias.apply_updates() // after this, the alias will be updated to | |
// reflect the mutation to b | |
class Alias { | |
public: | |
struct Update { | |
const at::Tensor new_val; | |
const std::vector<ViewMeta> view_metas; | |
}; | |
explicit Alias(const at::Tensor& base); | |
const at::Tensor& base() const; | |
size_t generation() const { | |
return generation_; | |
} | |
void add_update( | |
const at::Tensor& updated_val, | |
const std::vector<ViewMeta>& metas); | |
bool apply_updates(); | |
private: | |
// NB: base_ should always point to a tensor BELOW the current | |
// functionalization layer. This is mainly to avoid reference cycles. e.g. | |
// given `b = a.view(...)` Both a.storage_ and b.storage_ are a | |
// FunctionStorageImpl containing an Alias, with contains a Tensor `base_`. In | |
// this case (where a and b are FunctionalTensorWrapper's), base_ should point | |
// not to a, but to a's unwrapped value, a.value_` See Note | |
// [Functionalization: Alias Removal] for a diagram that shows this visually. | |
at::Tensor base_; | |
std::vector<Update> updates_; | |
// generation_ gets incremented every time a mutation is queued onto the | |
// alias. It is used to determine if a given tensor is "up to date", or if it | |
// needs to be regenerated from the alias. | |
size_t generation_ = 0; | |
}; | |
// FunctionalStorageImpl is a subclass of StorageImpl used by the | |
// functionalization pass. It has no underlying data (similar to meta storage). | |
// It also knows how to reflect mutations to tensors in the absence of a valid | |
// data pointer. It does this by separately storing an Alias object, which knows | |
// how to reflect mutations that may have happened to views of the original | |
// tensor. | |
struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { | |
explicit FunctionalStorageImpl(const Tensor& value); | |
void add_update( | |
const Tensor& updated_val, | |
const std::vector<ViewMeta>& view_metas); | |
bool apply_updates(); | |
const Tensor& base(); | |
size_t generation() const; | |
~FunctionalStorageImpl() override = default; | |
private: | |
at::functionalization::Alias alias_; | |
}; | |
} // namespace functionalization | |
} // namespace at | |