# mypy: ignore-errors

import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Union

import sympy

import torch
from torch import SymInt
import torch.fx as fx
import torch.nn as nn
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)
import torch.utils._pytree as pytree


log = logging.getLogger(__name__)


# These canonicalizations are needed here (and not as decompositions), because
# the ops we're trying to canonicalize to are CompositeImplicitAutograd.
def _canonicalize(fx_g):
    """Rewrite every ``aten._to_copy`` call in ``fx_g`` to ``aten.to``.

    Mutates ``fx_g`` in place, recompiles it, and returns it.
    """
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten._to_copy:
            node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g


@contextmanager
def _disable_jit_autocast():
    """Turn TorchScript autocast off for the duration of the ``with`` block.

    The previous autocast mode is restored on exit, even if the body raises.
    """
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)


@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with Torchscript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled. It is
            mutated in place (overloads stripped, some node targets/kwargs
            rewritten) before scripting.
        inps: Example inputs for the graph. Unless any of them is a
            ``FakeTensor``, the compiled module is run once on them before
            being returned.

    Returns:
        Torch scripted model.
    """

    with _disable_jit_autocast():
        strip_overloads(fx_g)

        # Rewrite dtype-only ``_to_copy`` calls (one positional arg and a single
        # ``dtype`` kwarg) to ``aten.to``.
        for node in fx_g.graph.nodes:
            if (
                node.target == torch.ops.aten._to_copy
                and len(node.args) == 1
                and len(node.kwargs) == 1
                and "dtype" in node.kwargs
            ):
                node.target = torch.ops.aten.to

        # Replace ``torch.device`` kwargs with their device-type string
        # (``torch.device.type``, e.g. "cuda"); any device index is dropped.
        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()

        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        # Run once on real inputs; fake tensors cannot be executed.
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f


def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    """Print ``fx_g``'s code, draw it via ``draw_graph`` under ``name``, and return it unchanged."""
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    """Return a boxed compiler that prints and draws each graph it receives under ``name``."""
    return make_boxed_compiler(
        partial(_draw_graph_compile, name=name)
    )


@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
    and can be used to check accuracy.

    .. warning::
        This API is experimental and likely to change.

    """
    return fx_g


class DebugInterpreter(fx.Interpreter):
    """FX interpreter that checks traced node metadata against real values.

    After executing each node that carries ``meta['val']``, every tensor output
    is compared with the recorded value. Dtype and size must match, and so must
    strides, except strides of empty tensors and of dimensions of size <= 1.
    Symbolic ints in the recorded metadata are first evaluated to concrete ints
    using the symbol mapping built in :meth:`run`.
    """

    def run(self, *args):
        # Map the graph's symbols to concrete values derived from ``args``;
        # run_node substitutes this mapping into recorded SymInt expressions.
        self.symbol_mapping = bind_symbols(self.module, *args)
        # Return the graph outputs: debug_nop hands out this bound method as
        # the compiled function, so dropping the result would make every call
        # return None.
        return super().run(*args)

    def run_node(self, n):
        def subst_symint(ni):
            # Evaluate a SymInt to a plain int under the current symbol
            # mapping; any other value is returned unchanged.
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert r.is_number, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            # Strides are only significant for non-empty tensors and for
            # dimensions whose size is greater than 1.
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if subst_symint(a.stride(idx)) != b.stride(idx) and subst_symint(a.size(idx)) > 1:
                        return False
            return True

        def check(nv, rv, desc):
            # ``desc`` is a thunk so the (potentially large) description is
            # only formatted when an assertion actually fails.
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert subst_symint_tuple(nv.size()) == rv.size(), \
                f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            same_strides = check_significant_strides(nv, rv)
            assert same_strides, f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"

        r = super().run_node(n)

        if 'val' in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta['val'])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here, the error itself is
            # harmless enough as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r


@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns a (slow) interpreter over the FX graph module that also checks
    various debugging properties (e.g., that tracing strides matched real
    strides.)
    """
    return DebugInterpreter(fx_g).run


@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    """Strip overloads from ``fx_g``, script it with TorchScript, and freeze it in eval mode."""
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f):
    """Wrap ``f`` with :func:`aot_function`, compiling both graphs via :func:`simple_ts_compile`."""
    return aot_function(f, simple_ts_compile)


aten = torch.ops.aten
# Ops decomposed by default in memory_efficient_fusion and _save_fx_default;
# the set of ops is turned into a decomposition table by get_decompositions.
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}
default_decompositions = get_decompositions(default_decompositions)


@make_boxed_compiler
def print_compile(fx_g, _):
    """Print ``fx_g``'s generated code and return the graph module unchanged."""
    print(fx_g.code)
    return fx_g


def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    **kwargs,
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. By default it compiles the generated forward and backward
    graphs with :func:`ts_compile` (TorchScript).

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        **kwargs: Any other overrides you want to make to the settings
            (``fw_compiler``, ``bw_compiler``, ``partition_fn``,
            ``decompositions``, or any other keyword accepted by
            :func:`aot_function` / :func:`aot_module`).

    Returns:
        Returns a ``Callable``  or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and are compiled with the
        configured compilers.

    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)


def debug_compile(fx_g, inps):
    """Dump ``fx_g`` for minification, sanity-run it, then compile it with :func:`ts_compile`.

    Writes the graph module to the ``foo`` folder in the current working
    directory, prints a ready-to-run minifier script, executes the saved module
    once on ``inps`` (after moving it to CUDA), and finally returns the result
    of ``ts_compile(fx_g, inps)``.

    NOTE(review): ``from foo import FxModule`` assumes the current working
    directory is importable (on ``sys.path``) — confirm for your setup.
    """
    fx_g.to_folder("foo")
    print(
        f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it  #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
  # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
  minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
    )
    from foo import FxModule

    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)


# Index appended to the names of graphs saved by _save_fx_default. Reset to 0
# by graph_dumper_aot and incremented after each backward graph is saved.
graph_index = 0


def get_inputs(input_data_path):
    """
    Return a random input for the given inputs meta generated from
    _save_fx_default.

    Args:
        input_data_path: Path to a ``.input`` file written by
            :func:`_save_fx_default`.

    Returns:
        A list with one generated value per metadata entry. A ``(type,)``
        entry becomes a Python scalar of that type. A
        ``(type, shape, stride, dtype, device)`` entry becomes a tensor with
        the recorded shape, dtype and device; the recorded stride is ignored.

    .. warning::
        The file is read with :func:`pickle.load`, which can execute arbitrary
        code. Only load ``.input`` files from a trusted source.
    """
    # Dtypes filled with ``torch.randint`` instead of ``torch.rand``. The
    # Python ``int``/``float`` entries are kept for backward compatibility.
    randint_dtypes = {
        torch.int,
        torch.int32,
        torch.int64,
        torch.bool,
        torch.uint8,
        int,
        float,
    }

    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)

    inputs = []
    for meta in inputs_meta:
        if len(meta) == 1:
            # Scalar entry ``(type,)``: unpack the type from the 1-tuple (the
            # tuple itself is not callable).
            (scalar_type,) = meta
            value = scalar_type(random.random())
        else:
            _, shape, _stride, dtype, device = meta
            if dtype in randint_dtypes:
                # NOTE(review): randint's upper bound is exclusive, so this
                # always produces zeros — confirm that is intended.
                value = torch.randint(0, 1, shape, dtype=dtype, device=device)
            else:
                value = torch.rand(shape, dtype=dtype, device=device)
        inputs.append(value)
    return inputs


def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """
    Compile ``gm`` with AOTAutograd while saving its forward, backward, and
    joint computation graphs.

    The graphs are stored in
    {folder_name}/{current_name}/{current_name}_forward_{graph_index},
    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
    Graphs with no nodes are skipped with a warning.

    The input metadata of each graph is stored in a .input file inside the
    graph's folder. These files can be loaded with pickle. Each one holds a
    list whose entries have the format (type, shape, stride, dtype, device);
    for type = int or float the entry is just (type,). For the joint graph,
    whose input is a pair of argument lists, the metadata of both lists is
    concatenated into a single flat list in the same format.
    If dump_example_input is True, the graph's example inputs are also stored
    in a .pt file.

    Since each function might produce multiple graphs, the graph_index is used
    to distinguish different graphs.
    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        # Returns a flat list of per-argument metadata tuples.
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta

    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.log(
                logging.WARNING,
                "No nodes in graph {%s}_{%s}_{%s}.",
                current_name,
                type_name,
                graph_index,
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        graph_name = f"{current_name}_{type_name}_{graph_index}"
        graph_dir = f"{folder_name}/{current_name}/{graph_name}"
        os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
        gm.to_folder(graph_dir)
        # Close the metadata file deterministically instead of leaking the handle.
        with open(f"{graph_dir}/{graph_name}.input", "wb") as meta_file:
            pickle.dump(input_meta, meta_file)
        if dump_example_input:
            torch.save(args, f"{graph_dir}/{graph_name}.pt")

    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        # The backward graph is saved last for each function, so advance the
        # index here for the next forward/backward/joint set.
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args, **partition_kwargs):
        graph_saver_helper(gm, joint_args, "joint")
        # Forward any extra keyword arguments from the caller to the
        # partitioner. NOTE(review): presumably newer AOTAutograd passes
        # e.g. ``num_fwd_outputs`` here — confirm against aot_autograd.
        return default_partition(gm, joint_args, **partition_kwargs)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )


# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """
    Dump the forward, backward, and joint computation graph.
    Example Usage:
    save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False)
    optimize_ctx = torchdynamo.optimize(
        save_fx_func
    )
    with torch.enable_grad():
        with optimize_ctx:
            result = forward_and_backward_pass(model, example_inputs)

    Calling this resets the module-level ``graph_index`` to 0 and returns a
    compiler, ``_save_fx_default`` with the first three arguments bound.
    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)