from typing import Callable, Optional

from torch._prims.context import TorchRefsMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx


def execute(
    gm: GraphModule,
    *args,
    executor: str = "aten",
    executor_parameters: Optional[dict] = None,
):
    """
    Prototype ATen executor.

    Just executes the given graph module with the default (ATen) PyTorch runtime.
    """
    if executor == "aten":
        return gm.forward(*args)

    msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
    raise ValueError(msg)
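
# Illustrative sketch (not part of the original module): `execute` expects a GraphModule,
# for example one produced by make_fx. With the default "aten" executor it simply runs
# the graph module's forward:
#
#     gm = make_fx(lambda a, b: torch.add(a, b))(torch.randn(3), torch.randn(3))
#     out = execute(gm, torch.randn(3), torch.randn(3), executor="aten")
#
# Any other value for `executor` raises ValueError.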


def make_traced(fn: Callable):
    """
    Returns a function that, when called, will
    trace its torch operations to prims and then
    execute those prims on the requested trace executor
    (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py and operations with positional args. All args must
    be tensors.
    In the near future all these restrictions will be lifted.

    Example usage:

        def foo(a, b):
            return torch.add(a, b)

        traced_foo = make_traced(foo)

        a = torch.randn((1, 2, 3, 4, 5), device='cuda')
        b = torch.randn((1, 2, 3, 4, 5), device='cuda')
        result = traced_foo(a, b, executor='aten')
    """

    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)

        with TorchRefsMode():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)

    return _traced
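

# Minimal, self-contained usage sketch (illustrative only, not part of the original
# module). It exercises make_traced end to end on CPU tensors; the docstring example
# above uses CUDA tensors instead.
if __name__ == "__main__":
    import torch

    def foo(a, b):
        return torch.add(a, b)

    traced_foo = make_traced(foo)
    a = torch.randn(1, 2, 3)
    b = torch.randn(1, 2, 3)
    result = traced_foo(a, b, executor="aten")
    assert torch.allclose(result, a + b)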