File size: 7,811 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple
from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
# Separator used to join nested C++ code lines when they are spliced into a
# generated code template (used by the default list branch of Unboxing).
connector: str = "\n\t"
def name(f: NativeFunction) -> str:
    """Return the name of the generated unboxing function for ``f``.

    The name is the unambiguous form of the function's schema name.
    """
    schema_name = f.func.name
    return schema_name.unambiguous_name()
@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into those Bindings, returning
    generated C++ code that performs the unboxing.

    A sample of the generated code (only the ``EValue&`` stack bindings and the
    ``*_base`` conversion lines are produced by this class; the function wrapper,
    profiling scope and kernel call are shown for context):

    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);
    }
    """

    # Callable that converts a JIT argument into its C++ type. It is invoked as
    # argument_type_gen(type, mutable=..., binds=...) and must return a NamedCType,
    # e.g. torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """
        Generate C++ code that unboxes every argument of a NativeFunction from the stack.

        :param args: bindings for the function's arguments, in stack order; each
            binding's ``argument`` must be an ``Argument``
        :return: a tuple of (the bindings renamed to the unboxed C++ variables,
            the lines of C++ code that perform the unboxing)
        :raises Exception: if a binding's ``argument`` is not an ``Argument``
        """
        # Bind every stack slot to a named EValue reference first; the per-argument
        # conversions emitted below read from these names.
        code_list = [f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(  # noqa: TRY002
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations (e.g. backing vectors for ArrayRef) must come before the
            # code that fills and references them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """
        Generate C++ code that unboxes the EValue named ``arg_name`` into a new C++ variable.

        :param t: a `Type` of an argument
        :param arg_name: name of the C++ ``EValue`` variable to unbox
        :param mutable: boolean for whether this argument type is mutable
        :return: a tuple of
            (1) the name of the newly created unboxed C++ variable,
            (2) its CType, as produced by ``argument_type_gen``,
            (3) the C++ code lines that perform the unboxing,
            (4) C++ declarations that must be emitted before those code lines
        :raises Exception: if ``t`` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(  # noqa: TRY002
                f"Cannot handle type {t}. arg_name: {arg_name}"
            )
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a base-type EValue with ``EValue::to<T>()``; needs no declarations."""
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox an optional EValue with ``EValue::toOptional<Elem>()``."""
        in_name = f"{arg_name}_opt_in"
        # Only the element's CType (for the template argument) and its declarations
        # are used here; the element's own conversion code is not emitted.
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
""".split(
                "\n"
            ),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox a list EValue.

        Lists of Tensor/int/SymInt/float/bool use a dedicated EValue accessor.
        Tensor?[] gets a special case for ATen vs. non-ATen builds. Every other
        element type falls back to building a std::vector and wrapping it.
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code: List[str] = []
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )
        # EValue accessors for lists of base types that have a dedicated accessor.
        # int and SymInt share toIntList; bool also handles sized lists, e.g. bool[4].
        base_list_accessors = {
            BaseTy.Tensor: "toTensorList",
            BaseTy.int: "toIntList",
            BaseTy.SymInt: "toIntList",
            BaseTy.float: "toDoubleList",
            BaseTy.bool: "toBoolList",
        }
        if isinstance(t.elem, BaseType) and t.elem.name in base_list_accessors:
            accessor = base_list_accessors[t.elem.name]
            code.extend(
                f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{accessor}();
""".split(
                    "\n"
                )
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
        elif (
            isinstance(t.elem, OptionalType)
            and isinstance(t.elem.elem, BaseType)
            and t.elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
{out_name}.push_back({elem_name});
}}
#else
torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
#endif
""".split(
                    "\n"
                )
            )
        else:
            # use ArrayRef as default.
            # NOTE(review): on this path nothing in the generated C++ declares
            # `{arg_name}_list_in` (the Tensor?[] branch above declares it, this one
            # does not), so the emitted loop iterates over an undefined identifier.
            # TODO confirm which EValue accessor should initialize it.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            code.extend(
                f"""
for (EValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split(
                    "\n"
                )
            )
        return code, decl
|