File size: 5,210 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# mypy: allow-untyped-defs
"""
The APIs in this file are exposed as `functorch.*`. They are thin wrappers
around the torch.func.* APIs that have deprecation warnings -- we're trying
to move people to the torch.func.* equivalents.
NB: We don't use *args, **kwargs in the signatures because that changes the
documentation.
"""
import textwrap
import warnings
from typing import Any, Callable, Optional, Tuple, Union
import torch._functorch.apis as apis
import torch._functorch.eager_transforms as _impl
import torch._functorch.make_functional as _nn_impl
import torch.nn as nn
from torch._functorch.eager_transforms import argnums_t
from torch._functorch.vmap import in_dims_t, out_dims_t
def get_warning(api, new_api=None, replace_newlines=False):
    """Build the deprecation message for a ``functorch.<api>`` entry point.

    Args:
        api: Name of the deprecated function as exposed under ``functorch``.
        new_api: Fully qualified name of the replacement. When ``None``,
            ``torch.func.<api>`` is advertised.
        replace_newlines: When true, every newline character is removed so the
            message fits on one line; otherwise the multi-line layout is kept.

    Returns:
        The formatted deprecation message.
    """
    replacement = f"torch.func.{api}" if new_api is None else new_api
    # Every fragment but the last ends in " \n": the trailing space keeps the
    # words separated once the newlines are stripped.
    fragments = [
        "We've integrated functorch into PyTorch. As the final step of the \n",
        f"integration, `functorch.{api}` is deprecated as of PyTorch \n",
        "2.0 and will be deleted in a future version of PyTorch >= 2.3. \n",
        f"Please use `{replacement}` instead; see the PyTorch 2.0 release notes \n",
        "and/or the `torch.func` migration guide for more details \n",
        "https://pytorch.org/docs/main/func.migrating.html",
    ]
    message = "".join(fragments)
    return message.replace("\n", "") if replace_newlines else message
def warn_deprecated(api, new_api=None):
    """Emit a ``FutureWarning`` announcing that ``functorch.<api>`` is deprecated.

    Args:
        api: Name of the deprecated ``functorch`` function.
        new_api: Fully qualified replacement name, forwarded to ``get_warning``.
    """
    # stacklevel=3 skips this helper and the deprecated wrapper that called it,
    # so the warning is attributed to the wrapper's caller.
    warnings.warn(
        get_warning(api, new_api, replace_newlines=True),
        category=FutureWarning,
        stacklevel=3,
    )
def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
    """Copy a docstring onto a deprecated wrapper and append a warning note.

    Args:
        functorch_api: The deprecated wrapper whose ``__doc__`` is assigned.
        torch_func_api: Function supplying the base docstring. When ``None``,
            the attribute of ``_impl`` named like ``functorch_api`` is used.
        new_api_name: Replacement name forwarded to ``get_warning``.

    Returns:
        ``None``. Nothing is assigned when the source function has no docstring.
    """
    wrapper_name = functorch_api.__name__
    doc_source = (
        getattr(_impl, wrapper_name) if torch_func_api is None else torch_func_api
    )
    base_doc = doc_source.__doc__
    # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO
    if base_doc is None:
        return
    # NOTE(review): both indent widths below are a single space as seen in this
    # file; confirm they match the indentation of the copied docstrings.
    directive_body = textwrap.indent(get_warning(wrapper_name, new_api_name), " ")
    directive = textwrap.indent("\n.. warning::\n\n" + directive_body, " ")
    functorch_api.__doc__ = base_doc + directive
def vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    randomness: str = "error",
    *,
    chunk_size=None,
) -> Callable:
    # Deprecated `functorch.vmap` shim; the warning points users at `torch.vmap`.
    # Its docstring is assigned by the module-level `setup_docs(vmap, ...)` call.
    warn_deprecated("vmap", new_api="torch.vmap")
    vmapped = apis.vmap(
        func,
        in_dims,
        out_dims,
        randomness,
        chunk_size=chunk_size,
    )
    return vmapped
def grad(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
) -> Callable:
    # Deprecated `functorch.grad` shim: warn, then forward unchanged to `apis.grad`.
    # Its docstring is assigned by the module-level `setup_docs(grad, ...)` call.
    warn_deprecated("grad")
    grad_fn = apis.grad(func, argnums, has_aux)
    return grad_fn
def grad_and_value(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
) -> Callable:
    # Deprecated `functorch.grad_and_value` shim: warn, then forward unchanged
    # to `apis.grad_and_value`.
    warn_deprecated("grad_and_value")
    wrapped = apis.grad_and_value(func, argnums, has_aux)
    return wrapped
def vjp(
    func: Callable,
    *primals,
    has_aux: bool = False,
):
    # Deprecated `functorch.vjp` shim: warn, then forward the positional
    # primals and the `has_aux` flag to `_impl.vjp`.
    warn_deprecated("vjp")
    result = _impl.vjp(func, *primals, has_aux=has_aux)
    return result
def jvp(
    func: Callable,
    primals: Any,
    tangents: Any,
    *,
    strict: bool = False,
    has_aux: bool = False,
):
    # Deprecated `functorch.jvp` shim: warn, then forward to `_impl.jvp` with
    # the keyword-only options passed through by name.
    warn_deprecated("jvp")
    result = _impl.jvp(
        func,
        primals,
        tangents,
        strict=strict,
        has_aux=has_aux,
    )
    return result
def jacrev(
    func: Callable,
    argnums: Union[int, Tuple[int]] = 0,
    *,
    has_aux=False,
    chunk_size: Optional[int] = None,
    _preallocate_and_copy=False,
):
    # Deprecated `functorch.jacrev` shim: warn, then forward every option to
    # `_impl.jacrev` unchanged.
    warn_deprecated("jacrev")
    forwarded_options = {
        "has_aux": has_aux,
        "chunk_size": chunk_size,
        "_preallocate_and_copy": _preallocate_and_copy,
    }
    return _impl.jacrev(func, argnums, **forwarded_options)
def jacfwd(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
    *,
    randomness: str = "error",
):
    # Deprecated `functorch.jacfwd` shim: warn, then forward to `_impl.jacfwd`
    # (`argnums`/`has_aux` positionally, `randomness` by keyword).
    warn_deprecated("jacfwd")
    jacobian_fn = _impl.jacfwd(
        func,
        argnums,
        has_aux,
        randomness=randomness,
    )
    return jacobian_fn
def hessian(func, argnums=0):
    # Deprecated `functorch.hessian` shim: warn, then forward to
    # `_impl.hessian`, passing `argnums` by keyword.
    warn_deprecated("hessian")
    hessian_fn = _impl.hessian(func, argnums=argnums)
    return hessian_fn
def functionalize(
    func: Callable,
    *,
    remove: str = "mutations",
) -> Callable:
    # Deprecated `functorch.functionalize` shim: warn, then forward to
    # `_impl.functionalize` with the `remove` option unchanged.
    warn_deprecated("functionalize")
    functionalized = _impl.functionalize(func, remove=remove)
    return functionalized
def make_functional(
    model: nn.Module,
    disable_autograd_tracking: bool = False,
):
    # Deprecated `functorch.make_functional` shim; the warning advertises
    # `torch.func.functional_call` as the replacement.
    warn_deprecated("make_functional", new_api="torch.func.functional_call")
    result = _nn_impl.make_functional(model, disable_autograd_tracking)
    return result
def make_functional_with_buffers(
    model: nn.Module,
    disable_autograd_tracking: bool = False,
):
    # Deprecated `functorch.make_functional_with_buffers` shim; the warning
    # advertises `torch.func.functional_call` as the replacement.
    warn_deprecated(
        "make_functional_with_buffers",
        new_api="torch.func.functional_call",
    )
    result = _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
    return result
def combine_state_for_ensemble(models):
    # Deprecated `functorch.combine_state_for_ensemble` shim; the warning
    # advertises `torch.func.stack_module_state` as the replacement.
    warn_deprecated(
        "combine_state_for_ensemble",
        new_api="torch.func.stack_module_state",
    )
    combined = _nn_impl.combine_state_for_ensemble(models)
    return combined
# Attach documentation to every deprecated wrapper at import time: each call
# copies the docstring of the function the wrapper forwards to and appends the
# deprecation warning note.

# Wrappers that forward to `apis`; the docstring source is passed explicitly.
setup_docs(vmap, apis.vmap, "torch.vmap")
setup_docs(grad, apis.grad)
setup_docs(grad_and_value, apis.grad_and_value)

# Wrappers that forward to the same-named function in `_impl`; `setup_docs`
# looks the docstring source up by name.
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
setup_docs(jacfwd)
setup_docs(hessian)
setup_docs(functionalize)

# Wrappers that forward to `_nn_impl`, each with a differently named replacement.
setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call")
# Use the docstring of the function this wrapper actually calls; it was
# previously given `make_functional`'s docstring.
setup_docs(
    make_functional_with_buffers,
    _nn_impl.make_functional_with_buffers,
    "torch.func.functional_call",
)
setup_docs(
    combine_state_for_ensemble,
    _nn_impl.combine_state_for_ensemble,
    "torch.func.stack_module_state",
)
|