import torch
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode


class AutogradStateOpsFailSafeguard(TorchFunctionMode):
    """
    Detect grad state ops during graph export and fail the export by raising an
    error, to avoid unexpected behavior. These grad mode ops include:
    `torch.no_grad`
    `torch.enable_grad`
    `torch.set_grad_enabled`

    Export with pre-dispatch mode is exempt.
    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unsupported_grad_mode_ops = [
            torch._C._set_grad_enabled,
        ]
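        # `torch.no_grad`, `torch.enable_grad`, and `torch.set_grad_enabled` all
        # route through `torch._C._set_grad_enabled`, so intercepting this single
        # private op covers every grad mode op listed in the docstring. The guard
        # only fires while tracing, confirmed below by checking for an active
        # PROXY torch dispatch mode; outside tracing, the ops pass through.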
        current_state = torch._C.is_grad_enabled()
        if func in unsupported_grad_mode_ops:
            assert len(args) == 1
            changed_state = args[0]
            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
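            # Pre-dispatch export is exempt because grad state ops are captured
            # there; a "change" to the already-current state is also harmless.
            # Only a real flip during non-pre-dispatch proxy tracing is unsafe.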
            if (
                mode
                and isinstance(mode, ProxyTorchDispatchMode)
                and not mode.pre_dispatch
                and changed_state != current_state
            ):
                raise RuntimeError(
                    f"Encountered autograd state manager op {func} trying to change global autograd state "
                    "while exporting. This is unsafe because we don't capture this op in torch.export "
                    "today, hence we can't reflect the user intention soundly. You can fix this by "
                    "adding a torch.no_grad() context around the export call."
                )
        return func(*args, **kwargs)
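
# A minimal sketch of how this safeguard might be exercised (hypothetical driver
# code, not part of this module; `make_fx` performs non-pre-dispatch proxy
# tracing, so flipping grad state inside the traced function should trigger the
# RuntimeError above):
#
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def flips_grad_state(x):
#       with torch.no_grad():  # calls torch._C._set_grad_enabled(False)
#           return x * 2
#
#   with AutogradStateOpsFailSafeguard():
#       make_fx(flips_grad_state)(torch.randn(3))  # expected: RuntimeError
#
#   # Fix suggested by the error message: pin the grad state around the export
#   # call, so the inner no_grad no longer changes the current state:
#   with torch.no_grad(), AutogradStateOpsFailSafeguard():
#       make_fx(flips_grad_state)(torch.randn(3))  # traces without raising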