import textwrap
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup

from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    OptionalType,
    Type,
)
from torchgen.utils import mapMaybe

# Maps torchgen base types to the C ABI types used in the shim function signatures.
base_type_to_c_type = {
    BaseTy.Tensor: "AtenTensorHandle",
    BaseTy.bool: "int32_t",
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "int64_t",
    BaseTy.Scalar: "double",
    BaseTy.float: "double",
    BaseTy.str: "const char*",
    BaseTy.DeviceIndex: "int32_t",
    BaseTy.Layout: "int32_t",
    BaseTy.MemoryFormat: "int32_t",
    BaseTy.ScalarType: "int32_t",
    BaseTy.Generator: "AtenGeneratorHandle",
}

# Maps torchgen base types to the ATen C++ types expected by the backend kernels.
base_type_to_aten_type = {
    BaseTy.Tensor: "at::Tensor",
    BaseTy.bool: "bool",
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "c10::SymInt",
    BaseTy.Scalar: "c10::Scalar",
    BaseTy.float: "double",
    BaseTy.str: "c10::string_view",
    BaseTy.DeviceIndex: "c10::DeviceIndex",
    BaseTy.Layout: "c10::Layout",
    BaseTy.MemoryFormat: "c10::MemoryFormat",
    BaseTy.ScalarType: "c10::ScalarType",
    BaseTy.Generator: "at::Generator",
}

# Cast or dereference applied at the callsite to turn a C ABI argument back into the
# corresponding ATen value; an empty string means the argument is passed through as is.
base_type_to_callsite_expr = {
    BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
    BaseTy.bool: "",
    BaseTy.int: "",
    BaseTy.SymInt: "",
    BaseTy.Scalar: "",
    BaseTy.float: "",
    BaseTy.str: "",
    BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
    BaseTy.Layout: "static_cast<c10::Layout>",
    BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
    BaseTy.ScalarType: "static_cast<c10::ScalarType>",
    BaseTy.Generator: "*generator_handle_to_generator_pointer",
}
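
# Taken together (illustrative example): a ScalarType argument named "dtype" is declared as
# "int32_t dtype" in the C shim, is typed as c10::ScalarType on the ATen side, and is
# forwarded to the kernel as static_cast<c10::ScalarType>(dtype).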


# Convert an argument's schema type and name into the pieces needed by the C shim:
# C parameter types, C parameter names, ATen C++ types, and the expressions used to
# rebuild the ATen values at the callsite.
def convert_arg_type_and_name(
    typ: Type, name: str
) -> Tuple[List[str], List[str], List[str], List[str]]:
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                [
                    f"{base_type_to_callsite_expr[typ.name]}({name})"
                    if base_type_to_callsite_expr[typ.name]
                    else name
                ],
            )
        elif typ.name == BaseTy.Device:
            # Device is passed through the C ABI as a (device type, device index) pair.
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            # Any remaining base type is not yet representable in the C shim ABI.
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
            typ.elem, name
        )
        j = 0  # index into c_types/names, which may hold more entries than aten_types
        new_aten_types = []
        new_callsite_exprs = []
        for aten_type in aten_types:
            # An optional value is passed as a (possibly null) pointer.
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # An optional list occupies two C arguments: data pointer and length.
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # An optional device occupies two C arguments: device type and index.
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j+1]})"
                )
                j += 2
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1

        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # A list is passed explicitly as a data pointer plus a length argument.
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)

        # The shim never mutates the list contents, so the data pointer is const.
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")

        atype = aten_types[0]
        callsite_exprs = []
        if atype == "bool":
            # There is no conversion from std::vector<bool> to c10::ArrayRef<bool>, so
            # fixed-size bool lists go through a size-templated pointer_to_list overload.
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "::std::optional<at::Tensor>":
            # Optional tensor lists are wrapped into a c10::List at the callsite.
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
            )
        else:
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")

        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )
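
# Illustrative trace for a list argument such as `int[] dims`: convert_arg_type_and_name on
# ListType(BaseType(int)) named "dims" yields roughly
#   C types:        ["const int64_t*", "int64_t"]
#   C names:        ["dims", "dims_len_"]
#   ATen types:     ["c10::ArrayRef<int64_t>"]
#   callsite exprs: ["pointer_to_list<int64_t>(dims, dims_len_)"]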


def zip_type_and_name(types: List[str], names: List[str]) -> List[str]:
    return [typ + " " + name for typ, name in zip(types, names)]


# Generate argument declarations and the matching callsite expressions.
def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]:
    types = []
    new_names = []
    callsite_exprs = []
    for arg in flat_arguments:
        new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
            arg.type, arg.name
        )
        types.extend(new_types)
        new_names.extend(names)
        callsite_exprs.extend(new_callsite_exprs)
    return zip_type_and_name(types, new_names), callsite_exprs
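
# For example (illustrative): the argument list `(Tensor self, int dim)` produces the
# declaration fragments ["AtenTensorHandle self", "int64_t dim"] and the callsite expressions
# ["*tensor_handle_to_tensor_pointer(self)", "dim"].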


# Returns are passed back through out-pointer parameters: each C shim function returns
# AOTITorchError, so the actual results are written into caller-provided pointers.
def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
    types = []
    names = []
    for idx, ret in enumerate(schema.returns):
        names.append(f"ret{idx}")
        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
            types.append(base_type_to_c_type[ret.type.name] + "*")
        else:
            raise NotImplementedError(
                f"TODO: add support for return type {repr(ret.type)}"
            )

    def convert_return(typ: BaseType, val: str) -> str:
        if typ.name == BaseTy.Tensor:
            return f"new_tensor_handle(std::move({val}));"
        elif typ.name == BaseTy.SymInt:
            return f"{val}.expect_int()"
        elif typ.name == BaseTy.Scalar:
            return f"{val}.toDouble()"
        else:
            return val

    # For these ops the generated code tolerates null return pointers: each result is only
    # written if its out pointer is non-null.
    ret_pointer_can_be_null = False
    unambiguous_name = schema.name.unambiguous_name()
    for name in [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_efficient_attention",
        "convolution_backward",
    ]:
        if name in unambiguous_name:
            ret_pointer_can_be_null = True
            break

    callsite_exprs: List[str] = []
    for idx, ret in enumerate(schema.returns):
        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
        rval = convert_return(ret.type, tmp)
        if ret_pointer_can_be_null:
            callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
        else:
            callsite_exprs.append(f"*{names[idx]} = {rval};")

    return zip_type_and_name(types, names), callsite_exprs
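
# Illustrative trace for an op returning a single Tensor: gen_returns adds the parameter
# declaration "AtenTensorHandle* ret0" and a callsite assignment that stores
# new_tensor_handle(std::move(tmp_result)) through ret0.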


# Cache keyed by (op name, device, backend call) so that the header and source generators
# hand back the same declaration/definition pair for a given shim function.
declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> Tuple[str, str]:
    func_name = schema.name.unambiguous_name()

    global declaration_definition_cache
    if (func_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(func_name, device, backend_call)]

    if schema.is_out_fn():
        # An out variant takes its out arguments first and writes results in place, so no
        # separate return assignments are generated.
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: List[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # In-place ops return their mutated input, so return values are skipped; functional
        # ops get out-pointer parameters via gen_returns.
        ret_declarations, ret_assignments = (
            ([], []) if schema.name.name.inplace else gen_returns(schema)
        )
        args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"

    tmp_result = "auto tmp_result = " if ret_assignments else ""
    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
    definition = f"""
{declaration} {{
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
        {tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), "            ")}
        );{textwrap.indent(ret_assignments_str, "        ")}
    }});
}}
"""
    declaration_definition_cache[(func_name, device, backend_call)] = (
        declaration,
        definition,
    )
    return declaration, definition
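
# Shape of the generated code (illustrative, for a hypothetical CPU op "foo(Tensor self) -> Tensor"):
#   declaration: AOTITorchError aoti_torch_cpu_foo(AtenTensorHandle self, AtenTensorHandle* ret0)
#   definition:  the same signature with a body that wraps the backend call in
#                AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE and stores the result through ret0.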


def gen_static_dispatch_backend_call_signature(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
) -> CppSignature:
    # The incoming sig is replaced by a DispatcherSignature derived from the schema; only its
    # symint-ness is used to pick between the symint and non-symint C++ signatures.
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    return cpp_sig


def gen_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
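
# For example (illustrative): with the CUDA backend index, a function named "foo" resolves to
# the static dispatch call string "at::cuda::foo".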


def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]:
    backend_index = None
    if backend_indices[dispatch_key].has_kernel(func) or (
        func.structured_delegate is not None
        and func.structured_delegate in func_group_mapping
        and backend_indices[dispatch_key].has_kernel(
            func_group_mapping[func.structured_delegate]
        )
    ):
        backend_index = backend_indices[dispatch_key]
    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
        # Fall back to the CompositeExplicitAutograd kernel if the backend has none.
        backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
    elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
        func
    ):
        # Then try the CompositeExplicitAutogradNonFunctional kernel.
        backend_index = backend_indices[
            DispatchKey.CompositeExplicitAutogradNonFunctional
        ]
    elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
        backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]

    return backend_index


def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[str]:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    return (
        None
        if backend_index is None
        else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
    )
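
# For example (illustrative): an op with root name "addmm" resolved to the CUDA backend index
# yields "#include <ATen/ops/addmm_cuda_dispatch.h>".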


def get_fallback_op_name(func: NativeFunction) -> str:
    return (
        f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
        if func.func.name.overload_name
        else f"{func.namespace}.{func.func.name.name}.default"
    )
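
# For example (illustrative): aten::add.Tensor maps to "aten.add.Tensor", while an op without
# an overload name such as aten::abs maps to "aten.abs.default".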


def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: Dict[DispatchKey, BackendIndex],
    header: bool,
) -> Optional[str]:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if backend_index is None:
        return None

    schema = func.func
    device = dispatch_key.lower()
    backend_call = gen_static_dispatch_backend_call(
        func,
        backend_index,
    )

    try:
        if header:
            declaration, _ = gen_declaration_and_definition(
                schema, device, backend_call
            )
            return f"AOTI_TORCH_EXPORT {declaration};"
        else:
            _, definition = gen_declaration_and_definition(schema, device, backend_call)
            return definition
    except NotImplementedError:
        # Ops with argument or return types the shim cannot express yet are skipped.
        return None


@dataclass(frozen=True)
class ShimGenerator:
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: Dict[DispatchKey, BackendIndex]
    header: bool

    @method_with_native_function
    def __call__(
        self,
        func: NativeFunction,
    ) -> Optional[str]:
        result = gen_c_shim(
            func,
            self.func_group_mapping,
            self.dispatch_key,
            self.backend_indices,
            self.header,
        )
        return result


def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: Dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    func_group_mapping, dispatch_key, backend_indices, header
                ),
                native_functions,
            )
        )
    )
    device = dispatch_key.lower()

    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if header:
        return f"""
{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {{
#endif

{body}

#ifdef __cplusplus
}} // extern "C"
#endif
"""

    else:
        return f"""
{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif

using namespace torch::aot_inductor;

{body}"""
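
# Minimal usage sketch (illustrative; variable names are placeholders and the real call sites
# live elsewhere in torchgen):
#
#   header_text = gen_aoti_c_shim(
#       native_functions, func_group_mapping, DispatchKey.CPU, backend_indices, header=True
#   )
#   source_text = gen_aoti_c_shim(
#       native_functions, func_group_mapping, DispatchKey.CPU, backend_indices, header=False,
#       includes=per_operator_includes,
#   )
#
# With header=True the result is the extern "C" header of AOTI_TORCH_EXPORT declarations; with
# header=False it is the matching definitions, where `includes` carries the per-operator
# dispatch headers (see get_header_for_aoti) used when AT_PER_OPERATOR_HEADERS is defined.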