# Functions for synthesizing magic methods for JIT-compiled dataclasses
import ast
import dataclasses
import inspect
import os
from functools import partial
from typing import Callable, Dict, List

from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext


def _get_fake_filename(cls, method_name):
return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)


def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
body = "\n".join(f" {b}" for b in body_lines)
decl = f"def {name}{signature}:\n{body}"
# Parse the function declaration
try:
py_ast = ast.parse(decl)
except SyntaxError as e:
# This should only happen if there's some unforeseeable change
# in the dataclasses module that makes our synthesized code fail
raise RuntimeError(
f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
"Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
) from e
fake_filename = _get_fake_filename(cls, name)
    # Bundle the parsed AST together with its source context
return ParsedDef(
py_ast,
ctx=SourceContext(
source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0
),
source=decl,
filename=fake_filename,
file_lineno=0,
)
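

# For illustration only (`Point` is a hypothetical dataclass, not part of this
# module): a call like
#   compose_fn(Point, "__eq__", ["val1 = self.x", "..."], "(self, other: Point) -> bool")
# builds a `decl` string along the lines of
#   def __eq__(self, other: Point) -> bool:
#     val1 = self.x
#     ...
# which is parsed with ast.parse and returned wrapped in a ParsedDef.

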
def synthesize__init__(cls) -> ParsedDef:
    # Supporting default factories the way users expect would require compiling
    # arbitrary callables such as lambdas, which TorchScript does not currently support.
if any(
field.default_factory is not dataclasses.MISSING
for field in dataclasses.fields(cls)
):
raise NotImplementedError(
"Default factory initializers are not supported in TorchScript dataclasses"
)
    # Read the generated __init__ signature directly from CPython's implementation.
    # It is almost correct as-is; only InitVar annotations need special handling.
signature = inspect.signature(cls.__init__)
# Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
# see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
init_vars: List[str] = []
params = []
for name, param in signature.parameters.items():
ann = param.annotation
if isinstance(ann, dataclasses.InitVar):
# The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
init_vars.append(name)
params.append(param.replace(annotation=ann.type)) # type: ignore[attr-defined]
else:
params.append(param)
signature = signature.replace(parameters=params)
body = [
# Assign all attributes to self
f"self.{field.name} = {field.name}"
for field in dataclasses.fields(cls)
if field.init and field.name not in init_vars
]
# Call user's impl of __post_init__ if it exists
if hasattr(cls, "__post_init__"):
body.append("self.__post_init__(" + ", ".join(init_vars) + ")")
return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature))
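

# Sketch for illustration, assuming a hypothetical dataclass (not defined here):
#   @dataclass
#   class Foo:
#       x: int
#       seed: InitVar[int]
#       def __post_init__(self, seed: int) -> None: ...
# synthesize__init__(Foo) produces source roughly equivalent to
#   def __init__(self, x: int, seed: int) -> None:
#     self.x = x
#     self.__post_init__(seed)
# Note that the InitVar annotation is unwrapped to its underlying type and the
# value is forwarded to __post_init__ rather than stored on self.

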
# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
return compose_fn(
cls,
"__repr__",
[
f"return '{cls.__name__}("
+ ", ".join(
[
f"{field.name}=self.{field.name}"
for field in dataclasses.fields(cls)
if field.repr
]
)
+ ")'"
],
signature="(self) -> str",
)
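

# For a hypothetical `Foo` with fields x and y, the synthesized body is a single
# constant string literal:
#   return 'Foo(x=self.x, y=self.y)'
# The attribute references sit inside the quotes; since the interpreter never
# calls __repr__, this placeholder only has to compile, not render real values.

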
def synthesize__hash__(cls) -> ParsedDef:
return compose_fn(
cls,
"__hash__",
[
            # Placeholder so that compilation does not fail; it is never actually
            # called, since the TorchScript interpreter does not invoke custom
            # __hash__ implementations.
"raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
],
signature="(self) -> int",
)


# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
return synthesize_comparison(
cls,
name,
allow_eq=True,
raise_on_none=False,
inner=[f"if val1 {converse} val2: return False"],
)
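

# Rough sketch of the generated __eq__ for a hypothetical two-field dataclass Foo
# (via synthesize_comparison with converse="!=", allow_eq=True):
#   def __eq__(self, other: Foo) -> bool:
#     val1 = self.x
#     val2 = other.x
#     if val1 != val2: return False
#     val1 = self.y
#     val2 = other.y
#     if val1 != val2: return False
#     return True

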
def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
return synthesize_comparison(
cls,
name,
allow_eq,
raise_on_none=True,
inner=[
f"if val1 {op} val2: return True",
f"elif val2 {op} val1: return False",
],
)
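

# Rough sketch of the generated __lt__ (op="<", allow_eq=False) for the same
# hypothetical Foo; fields compare lexicographically, falling through to the
# next field only when the current pair is equal:
#   def __lt__(self, other: Foo) -> bool:
#     val1 = self.x
#     val2 = other.x
#     if val1 < val2: return True
#     elif val2 < val1: return False
#     ...
#     return False

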
def synthesize_comparison(
cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
) -> ParsedDef:
body = []
for field in dataclasses.fields(cls):
if not field.compare:
continue
body.extend(
[
f"val1 = self.{field.name}",
f"val2 = other.{field.name}",
]
)
body.extend(
inner
if not is_optional(field.type)
else [
# Type refinement for optional fields; we need this to avoid type errors from the interpreter
"if val1 is not None and val2 is not None:",
*[" " + line for line in inner],
"elif (val1 is None) != (val2 is None):",
f" raise TypeError('Cannot compare {cls.__name__} with None')"
if raise_on_none
else " return False",
]
)
body.append(f"return {allow_eq}")
return compose_fn(
cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool"
)
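

# For an Optional field, the inner comparison above is wrapped in a None check;
# with raise_on_none=True the refined block looks roughly like
#   if val1 is not None and val2 is not None:
#     if val1 < val2: return True
#     elif val2 < val1: return False
#   elif (val1 is None) != (val2 is None):
#     raise TypeError('Cannot compare Foo with None')
# (`Foo` again stands in for the actual class name).

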
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
"__init__": synthesize__init__,
"__repr__": synthesize__repr__,
"__hash__": synthesize__hash__,
"__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
"__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
"__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
"__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
"__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
"__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
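
# For illustration: given some dataclass type `cls` (a hypothetical variable), a
# caller can synthesize a method by looking it up in the table:
#   parsed = DATACLASS_MAGIC_METHODS["__eq__"](cls)
# Every value takes the class as its only positional argument (the partials
# pre-bind name/op/converse) and returns a ParsedDef ready for compilation.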