# Functions for synthesizing magic methods for JIT-compiled dataclasses
import os
from functools import partial
from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX
from torch._sources import ParsedDef, SourceContext
from typing import Callable, Dict, List
import ast
import dataclasses
import inspect
import sys

def _get_fake_filename(cls, method_name):
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)


def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    body = '\n'.join(f'  {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e
    fake_filename = _get_fake_filename(cls, name)
    # Bundle the parsed AST with a source context so that compiler errors
    # point at the synthesized source rather than a real file
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )
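
# For illustration: given body_lines=['self.x = x'] and signature='(self, x: int) -> None',
# compose_fn builds (and parses) a declaration along the lines of
#
#   def __init__(self, x: int) -> None:
#     self.x = x
#
# This is a sketch of the synthesized source, not output copied from a real run.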


def synthesize__init__(cls) -> ParsedDef:
    # Supporting default factories in the way that people expect would require us to
    # allow compiling lambda functions, which is not currently supported.
    if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
        raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct, except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only possible on Python 3.8+, where a `type` attribute was added to InitVar;
    # see https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c)
    init_vars: List[str] = []
    if sys.version_info >= (3, 8):
        params = []
        for name, param in signature.parameters.items():
            ann = param.annotation

            if isinstance(ann, dataclasses.InitVar):
                # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
                init_vars.append(name)
                params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
            else:
                params.append(param)

        signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f'self.{field.name} = {field.name}'
        for field in dataclasses.fields(cls)
        if field.init and field.name not in init_vars
    ]
    # Call the user's impl of __post_init__ if it exists, forwarding any InitVars
    if hasattr(cls, '__post_init__'):
        body.append('self.__post_init__(' + ', '.join(init_vars) + ')')

    return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))
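
# For illustration, a hypothetical dataclass like
#
#   @dataclass
#   class Point:
#       x: float
#       y: float = 0.0
#
# would yield a synthesized __init__ along the lines of
#
#   def __init__(self, x: float, y: float = 0.0) -> None:
#     self.x = x
#     self.y = y
#
# The exact signature text comes from inspect.signature, so it may differ slightly.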

# This is a placeholder at the moment, since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__repr__',
        [f"return '{cls.__name__}(" + ", ".join([
            f"{field.name}=self.{field.name}"
            for field in dataclasses.fields(cls) if field.repr
        ]) + ")'"],
        signature='(self) -> str'
    )
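
# For illustration: for the hypothetical Point class above, this synthesizes
#
#   def __repr__(self) -> str:
#     return 'Point(x=self.x, y=self.y)'
#
# Note that the body is a literal string, not an f-string: it never interpolates
# the field values. That's fine for a placeholder that is never actually called.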

def synthesize__hash__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__hash__',
        [
            # This is just a placeholder to prevent compilation from failing; it won't even get called
            # right now because the TorchScript interpreter doesn't call custom __hash__ implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature='(self) -> int'
    )

# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])
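
# For illustration: __eq__ is synthesized with converse='!=', so each field contributes
#
#   if val1 != val2: return False
#
# and the function falls through to `return True` once every compared field matches.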

def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])
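
# For illustration: __lt__ is synthesized with op='<' and allow_eq=False, giving a
# field-by-field lexicographic comparison
#
#   if val1 < val2: return True
#   elif val2 < val1: return False
#
# per field; if all compared fields are equal, the trailing `return allow_eq` decides
# the result (False for strict comparisons like __lt__, True for __le__).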

def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])
        body.extend(
            inner if not is_optional(field.type) else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *['  ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f"  raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else "  return False"
            ]
        )

    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')
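
# For illustration: for a hypothetical class with fields `x: int` and `y: Optional[int]`,
# synthesizing __eq__ produces a body along the lines of
#
#   val1 = self.x
#   val2 = other.x
#   if val1 != val2: return False
#   val1 = self.y
#   val2 = other.y
#   if val1 is not None and val2 is not None:
#     if val1 != val2: return False
#   elif (val1 is None) != (val2 is None):
#     return False
#   return True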

DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
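
# A minimal sketch of how a caller might consume this table (the real lookup lives in
# the TorchScript frontend; `MyDataclass` and `needs_synthesis` are hypothetical names):
#
#   for method_name, synthesizer in DATACLASS_MAGIC_METHODS.items():
#       if method_name in needs_synthesis:
#           parsed = synthesizer(MyDataclass)  # returns a ParsedDef for compilation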