# Functions for synthesizing magic methods for JIT-compiled dataclasses
import ast
import dataclasses
import inspect
import os
from functools import partial
from typing import Callable, Dict, List

from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext


def _get_fake_filename(cls, method_name):
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)


def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    body = "\n".join(f"  {b}" for b in body_lines)
    decl = f"def {name}{signature}:\n{body}"

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e
    fake_filename = _get_fake_filename(cls, name)

    # Bundle the parsed AST with its source context
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl, filename=fake_filename, file_lineno=0, leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0,
    )


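# For illustration (hypothetical inputs; `Point` is not defined in this module):
# compose_fn(Point, "__eq__", ["val1 = self.x", "if val1 != other.x: return False", "return True"],
# signature="(self, other: Point) -> bool") parses a declaration of the form
#
#   def __eq__(self, other: Point) -> bool:
#     val1 = self.x
#     if val1 != other.x: return False
#     return True

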
def synthesize__init__(cls) -> ParsedDef:
    # Supporting default factories in the way that people expect would require
    # compiling lambda functions, which is not currently supported.
    if any(
        field.default_factory is not dataclasses.MISSING
        for field in dataclasses.fields(cls)
    ):
        raise NotImplementedError(
            "Default factory initializers are not supported in TorchScript dataclasses"
        )

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
    # see the CPython commit at https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    params = []
    for name, param in signature.parameters.items():
        ann = param.annotation

        if isinstance(ann, dataclasses.InitVar):
            # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
            init_vars.append(name)
            params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
        else:
            params.append(param)

    signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f"self.{field.name} = {field.name}"
        for field in dataclasses.fields(cls)
        if field.init and field.name not in init_vars
    ]
    # Call the user's impl of __post_init__ if it exists
    if hasattr(cls, "__post_init__"):
        body.append("self.__post_init__(" + ", ".join(init_vars) + ")")

    return compose_fn(cls, "__init__", body or ["pass"], signature=str(signature))


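# Illustration (hypothetical dataclass, not part of this module): given
#
#   @dataclasses.dataclass
#   class Foo:
#       x: int
#       scale: dataclasses.InitVar[int]
#
#       def __post_init__(self, scale: int):
#           self.x *= scale
#
# the synthesized source is roughly
#
#   def __init__(self, x: int, scale: int) -> None:
#     self.x = x
#     self.__post_init__(scale)
#
# Note that `scale` keeps its position in the signature (with the InitVar wrapper
# stripped) but is never assigned to an attribute; it is only forwarded to __post_init__.

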
# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    return compose_fn(
        cls,
        "__repr__",
        [
            f"return '{cls.__name__}("
            + ", ".join(
                [
                    f"{field.name}=self.{field.name}"
                    for field in dataclasses.fields(cls)
                    if field.repr
                ]
            )
            + ")'"
        ],
        signature="(self) -> str",
    )


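# For the hypothetical Foo above, this synthesizes literally
#   return 'Foo(x=self.x)'
# i.e. a constant string with no interpolation of the field values, which is
# acceptable for now given the placeholder status noted above.

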
def synthesize__hash__(cls) -> ParsedDef:
    return compose_fn(
        cls,
        "__hash__",
        [
            # This is just a placeholder to prevent compilation from failing; it won't
            # actually be called right now because the TorchScript interpreter doesn't
            # invoke custom __hash__ implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature="(self) -> int",
    )


# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    return synthesize_comparison(
        cls,
        name,
        allow_eq=True,
        raise_on_none=False,
        inner=[f"if val1 {converse} val2: return False"],
    )


def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    return synthesize_comparison(
        cls,
        name,
        allow_eq,
        raise_on_none=True,
        inner=[
            f"if val1 {op} val2: return True",
            f"elif val2 {op} val1: return False",
        ],
    )


def synthesize_comparison(
    cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
) -> ParsedDef:
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend(
            [
                f"val1 = self.{field.name}",
                f"val2 = other.{field.name}",
            ]
        )
        body.extend(
            inner
            if not is_optional(field.type)
            else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *["  " + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f"  raise TypeError('Cannot compare {cls.__name__} with None')"
                if raise_on_none
                else "  return False",
            ]
        )

    body.append(f"return {allow_eq}")
    return compose_fn(
        cls, name, body, signature=f"(self, other: {cls.__name__}) -> bool"
    )


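# Illustration: for a hypothetical dataclass Bar with fields `a: int` and
# `b: Optional[int]`, synthesize_inequality(Bar, "__lt__", op="<", allow_eq=False)
# produces roughly
#
#   def __lt__(self, other: Bar) -> bool:
#     val1 = self.a
#     val2 = other.a
#     if val1 < val2: return True
#     elif val2 < val1: return False
#     val1 = self.b
#     val2 = other.b
#     if val1 is not None and val2 is not None:
#       if val1 < val2: return True
#       elif val2 < val1: return False
#     elif (val1 is None) != (val2 is None):
#       raise TypeError('Cannot compare Bar with None')
#     return False

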
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
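

if __name__ == "__main__":
    # A minimal, hedged smoke test (not part of the original module): run each
    # synthesizer against a toy dataclass and print the source that compose_fn
    # parsed. `_Point` is a hypothetical example class; the printout relies on
    # ParsedDef exposing the `source` field passed to it above.
    @dataclasses.dataclass
    class _Point:
        x: float
        y: float

    for _method, _synthesizer in DATACLASS_MAGIC_METHODS.items():
        _parsed = _synthesizer(_Point)
        print(f"--- {_method} ---")
        print(_parsed.source)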