File size: 1,871 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode


class AutogradStateOpsFailSafeguard(TorchFunctionMode):
    """

    Detect grad state ops during exporting the graph and fail the process by

    raising an error, to avoid unexpected behavior. Those grad mode ops could be:

    `torch.no_grad`

    `torch.enable_grad`

    `torch.set_grad_enabled`



    Export with predispatch mode is exempted.

    """

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unsupported_grad_mode_ops = [
            torch._C._set_grad_enabled,
        ]
        # It's only enabled while tracing, by confirming the torch dispatch mode is
        # any active PROXY. This is to allow the autograd ops out of tracing.
        current_state = torch._C.is_grad_enabled()
        if func in unsupported_grad_mode_ops:
            assert len(args) == 1
            changed_state = args[0]
            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
            # Intend to check if it's not the pre_dispatch mode. It's allowed to use
            # autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
            if (
                mode
                and isinstance(mode, ProxyTorchDispatchMode)
                and not mode.pre_dispatch
                and changed_state != current_state
            ):
                raise RuntimeError(
                    f"Encountered autograd state manager op {func} trying to change global autograd state "
                    "while exporting. This is unsafe because we don't capture this op in torch.export "
                    "today, hence we can't reflect the user intention soundly."
                )
        return func(*args, **kwargs)