import weakref

import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._ops import OpOverload
from torch.library import Library
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    OperatorName,
    OptionalType,
    SchemaKind,
)

from .autograd import autograd_not_implemented


def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
"""Given a mutable operator, registers the functional variant. |
|
|
|
This API also correctly links the functional variant with the mutable |
|
operator for the purposes of functionalization. |
|
|
|
All of the new registrations are performed on the ``lib`` passed in. |
|
|
|
Arguments: |
|
lib (Library): Should be a torch.library.Library object that has |
|
the same namespace as ``mutable_op``'s namespace. |
|
lib will be used to register the new functional op as well |
|
as a functionalization kernel for the ``mutable_op`` |
|
If you don't have a library handy, use |
|
``torch.library.Library(ns, 'FRAGMENT')`` to construct one. |
|
new_op_name (str): The name of the functional operator (without the |
|
namespace). If no namespace, the new functional variant will be |
|
accessible under ``torch.ops.{lib.ns}.new_op_name``. |
|
mutable_op (OpOverload): The mutable custom operator. Note |
|
that you may need to add a `.default` to it, like |
|
`torch.ops.aten.abs_.default`. |
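
    Example (``mylib`` and its mutable op with schema
    ``mylib::bump(Tensor(a!) x) -> ()`` are hypothetical; assume they are
    already defined)::

        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> register_functional_op(lib, "bump_functional", torch.ops.mylib.bump.default)
        >>> # torch.ops.mylib.bump_functional(x) clones ``x``, runs the mutable op
        >>> # on the clone, and returns the updated clone instead of writing into ``x``.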
    """
    validate(mutable_op)
    schema = functional_schema(new_op_name, mutable_op)
    lib.define(schema)

    functional_impl = construct_functional_impl(mutable_op)
    lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')

    functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default
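
    # There is no general way to derive an autograd formula for the generated
    # functional op, so register an Autograd kernel that refuses to
    # differentiate (it errors when grad is required) rather than silently
    # producing wrong gradients.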
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')
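
    # The functionalization kernel closes over ``mutable_op``; a weakref.proxy
    # is used, presumably to avoid keeping the OpOverload alive through its own
    # registered kernel.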
    f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)

    lib.impl(mutable_op, f_kernel, 'Functionalize')


def construct_functional_impl(mutable_op):
    def functional_impl(*args):
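        # Clone every argument that the schema marks as mutated, run the
        # original mutable op on the clones, and return the mutated clones
        # as extra outputs (in place of the in-place writes).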
        new_args = []
        extra_rets = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if is_write:
                cloned = arg.clone() if arg is not None else None
                new_args.append(cloned)
                extra_rets.append(cloned)
            else:
                new_args.append(arg)
        result = mutable_op(*new_args)
        if result is None:
            return tuple(extra_rets)
        if isinstance(result, tuple):
            return (*result, *extra_rets)
        return (result, *extra_rets)
    return functional_impl


def construct_functionalization_kernel(mutable_op, functional_op):
    def kernel(*args):
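        # If none of the tensor args are functional tensors there is nothing to
        # functionalize; drop the Functionalize key and run the mutable op
        # directly.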
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)
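
        # Mixed functional / non-functional tensor inputs are not supported by
        # this kernel, so insist that every tensor arg is a
        # FunctionalTensorWrapper before unwrapping.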
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            raise RuntimeError(f"{mutable_op}: expected all args to be FunctionalTensorWrapper")
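
        # Sync any pending updates on each FunctionalTensorWrapper and unwrap
        # it, so the functional op below operates on plain tensors.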
        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)
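
        # The functional op returns the original outputs first, followed by the
        # new values of every would-be-mutated input (see functional_impl).
        # Wrap the real outputs back into functional tensors and propagate the
        # trailing values into the original mutable inputs.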
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # An optional mutable arg is either absent in both places or present
            # in both; skip the absent case and write the new value back otherwise.
            assert (arg is None) == (new_value is None)
            if arg is None:
                continue
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output

    return kernel


def validate(mutable_op: OpOverload):
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
    for ret in schema.returns:
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")
    for arg in schema.arguments.flat_all:
        if arg.type.is_tensor_like() and (
            arg.type != BaseType(BaseTy.Tensor)
            and arg.type != OptionalType(BaseType(BaseTy.Tensor))
        ):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input. "
                "Please file an issue.")


def functional_schema(new_op_name, op: OpOverload):
    schema = FunctionSchema.parse(str(op._schema))
    schema = schema.signature().with_name(OperatorName.parse(new_op_name))
    return str(schema)


def mutable_args(op: OpOverload):
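    # One flag per schema argument: True when the argument is written to, e.g.
    # ``foo(Tensor(a!) x, Tensor y) -> ()`` yields (True, False).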
    return tuple(False if arg.alias_info is None else arg.alias_info.is_write
                 for arg in op._schema.arguments)