File size: 4,603 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
"""
To run this file by hand from the root of the PyTorch
repository, run:
python -m tools.autograd.gen_autograd \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
$OUTPUT_DIR \
tools/autograd
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""
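# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#                             (and the Python bindings for them)
#  gen_variable_type.py: generates VariableType.h/cpp, which contain all tensor methods
#  gen_inplace_or_view_type.py: generates the inplace-or-view type
#  gen_trace_type.py: generates the tracing type
#  gen_variable_factories.py: generates variable_factories.h
#  gen_view_funcs.py: generates ViewFuncs.h/cpp
#  gen_python_functions.py: generates Python bindings to THPVariable
#
# derivatives.yaml is parsed by load_derivatives.py.
#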
import argparse
import os
from typing import List
from torchgen.api import cpp
from torchgen.api.autograd import (
match_differentiability_info,
NativeFunctionWithDifferentiabilityInfo,
)
from torchgen.gen import parse_native_yaml
from torchgen.selective_build.selector import SelectiveBuilder
from . import gen_python_functions
from .gen_autograd_functions import (
gen_autograd_functions_lib,
gen_autograd_functions_python,
)
from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_trace_type import gen_trace_type
from .gen_variable_factories import gen_variable_factories
from .gen_variable_type import gen_variable_type
from .gen_view_funcs import gen_view_funcs
from .load_derivatives import load_derivatives
def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Generate the C++ autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory the generated files are written to.
        autograd_dir: directory containing derivatives.yaml and the
            ``templates`` subdirectory.
        operator_selector: decides which native functions are kept (via
            ``is_native_function_selected_for_training``) for the generators
            that consume the differentiability-annotated function list.
        disable_autograd: when True, VariableType, inplace-or-view and trace
            generation are skipped; Functions.h/cpp, variable_factories.h and
            ViewFuncs.h/cpp are still generated.
    """
    # derivatives.yaml yields both the per-op differentiability infos and the
    # dispatch keys that the VariableType generator needs.
    derivatives_yaml = os.path.join(autograd_dir, "derivatives.yaml")
    diff_infos, dispatch_keys = load_derivatives(
        derivatives_yaml, native_functions_path, tags_path
    )
    templates = os.path.join(autograd_dir, "templates")

    all_native_funcs = parse_native_yaml(
        native_functions_path, tags_path
    ).native_functions

    # Keep only the operators selected for training, ordered by C++ name so
    # the generated output is deterministic.
    is_selected = operator_selector.is_native_function_selected_for_training
    training_fns = sorted(
        (nf for nf in all_native_funcs if is_selected(nf)),
        key=lambda nf: cpp.name(nf.func),
    )
    paired: List[NativeFunctionWithDifferentiabilityInfo] = (
        match_differentiability_info(training_fns, diff_infos)
    )

    if not disable_autograd:
        # VariableType.h/cpp
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            paired,
            templates,
            dispatch_keys,
        )
        gen_inplace_or_view_type(
            out, native_functions_path, tags_path, paired, templates
        )
        # Tracing receives the full native function list: the operator
        # filter is not applied because tracing sources are excluded in
        # selective builds.
        gen_trace_type(out, all_native_funcs, templates)

    # Functions.h/cpp
    gen_autograd_functions_lib(out, diff_infos, templates)
    # variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, templates)
    # ViewFuncs.h/cpp
    gen_view_funcs(out, paired, templates)
def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory the generated files are written to.
        autograd_dir: directory containing derivatives.yaml, deprecated.yaml
            and the ``templates`` subdirectory.
    """
    # Only the differentiability infos are needed here; the dispatch keys
    # returned alongside them are only used by the C++ VariableType generator.
    differentiability_infos, _ = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )
    template_path = os.path.join(autograd_dir, "templates")

    # Generate the Python-side bindings for the autograd functions
    # (torch::autograd::Node subclasses). This is NOT Functions.h/cpp, which
    # gen_autograd_functions_lib produces. NOTE(review): presumably this
    # writes python_functions.h/cpp -- confirm in gen_autograd_functions.py.
    gen_autograd_functions_python(out, differentiability_infos, template_path)

    # Generate the Python bindings to THPVariable; deprecated.yaml is passed
    # through to the generator alongside the native function definitions.
    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
    gen_python_functions.gen(
        out, native_functions_path, tags_path, deprecated_path, template_path
    )
def main() -> None:
    """Command-line entry point: parse the four positional paths and run
    ``gen_autograd`` with a no-op operator selector.

    Usage: python -m tools.autograd.gen_autograd NATIVE TAGS OUT AUTOGRAD
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # Each positional gets its own metavar so --help/usage output labels the
    # tags.yaml path correctly.
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )
if __name__ == "__main__":
    # Script entry point; the expected arguments are described in the module
    # docstring and in main()'s argparse definitions.
    main()
|