# mypy: allow-untyped-defs | |
import contextlib | |
from typing import Tuple, Union | |
import torch | |
from torch._C._functorch import ( | |
get_single_level_autograd_function_allowed, | |
set_single_level_autograd_function_allowed, | |
unwrap_if_dead, | |
) | |
from torch.utils._exposed_in import exposed_in | |
# Names exported by ``from <this module> import *``.  ``exposed_in`` is not
# defined in this file; it is imported above from ``torch.utils._exposed_in``
# and re-exported here.
__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
@contextlib.contextmanager
def enable_single_level_autograd_function():
    """Context manager that turns on the single-level autograd-function flag.

    On entry, the value currently reported by
    ``get_single_level_autograd_function_allowed()`` is saved and the flag is
    set to ``True``.  On exit -- whether the ``with`` body finishes normally
    or raises -- the saved value is written back with
    ``set_single_level_autograd_function_allowed``.

    Yields:
        None.
    """
    # Read the previous value *before* entering ``try``: if the getter itself
    # fails, the error propagates unchanged instead of being masked by an
    # UnboundLocalError raised from the ``finally`` clause.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        # Restore whatever was in effect before, not unconditionally False,
        # so nested uses leave the outer setting intact.
        set_single_level_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args):
    """Return ``args`` as a tuple with each tensor passed through ``unwrap_if_dead``.

    Only the direct elements of ``args`` are inspected; nested containers are
    not traversed.  Elements that are not ``torch.Tensor`` instances are
    carried over unchanged, in their original order.

    Args:
        args: An iterable of arbitrary objects.

    Returns:
        A tuple with one entry per element of ``args``.
    """
    # A flat pass is used on purpose instead of tree_map_only: this avoids
    # the pytree machinery for performance reasons.
    unwrapped = []
    for item in args:
        if isinstance(item, torch.Tensor):
            item = unwrap_if_dead(item)
        unwrapped.append(item)
    return tuple(unwrapped)
# Type alias accepting either a single ``int`` or a tuple of ``int``s of any
# length.  NOTE(review): presumably the annotation for "argnums"-style
# parameters (argument indices) -- confirm against the code that uses it.
argnums_t = Union[int, Tuple[int, ...]]