| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import dataclasses |
| import hashlib |
| import re |
| import typing |
| from enum import IntEnum |
| from typing import Any, Dict, Optional, Union |
| |
| from torch._export.serde import schema |
| from torch._export.serde.union import _Union |
| |
| |
class SchemaUpdateError(Exception):
    """Error raised by ``_check`` when a schema validation condition fails."""
| |
| |
| def _check(x, msg): |
| if not x: |
| raise SchemaUpdateError(msg) |
| |
| |
def _staged_schema():
    """Build a plain-dict description of the definitions in the ``schema`` module.

    Every public name of ``schema`` that is defined in that module is
    collected. ``IntEnum`` subclasses become ``{"kind": "enum", ...}``
    entries, dataclasses deriving from ``_Union`` become ``"union"`` entries,
    and all other dataclasses become ``"struct"`` entries. The two version
    constants are appended last.

    Returns:
        Dict mapping each type name to ``{"kind": ..., "fields": ...}``, plus
        ``"SCHEMA_VERSION"`` (as a list) and ``"TREESPEC_VERSION"``.

    Raises:
        AssertionError: On an unsupported annotation, an unknown kind of
            definition, an ``Optional`` field whose default is not ``None``,
            or a non-positive version number.
    """
    staged: Dict[str, Any] = {}
    defs = {}

    def _describe_fields(ty):
        """Return ``{field name: {"type": ..., ["default": ...]}}`` for a dataclass."""

        def _render_type(t):
            """Render an annotation as its schema string, e.g. ``List[int]``."""
            if isinstance(t, type):
                return t.__name__
            if isinstance(t, str):
                # String annotations must name something defined in schema.
                assert t in defs
                return t

            origin = typing.get_origin(t)
            if origin:
                params = typing.get_args(t)
                if origin == Union:
                    # Only the Optional[X] form of Union is accepted.
                    assert len(params) == 2 and params[1] == type(None)
                    return f"Optional[{_render_type(params[0])}]"
                if origin == tuple and params == ():
                    return "Tuple[()]"
                if origin == list:
                    head = "List"
                elif origin == dict:
                    head = "Dict"
                elif origin == tuple:
                    head = "Tuple"
                else:
                    raise AssertionError(f"Type {t} is not supported in export schema.")
                rendered = ", ".join([_render_type(p) for p in params])
                return f"{head}[{rendered}]"

            if t == ():
                return "()"
            raise AssertionError(f"Type {t} is not supported in export schema.")

        def _render_field(f):
            """Describe one dataclass field: its type and, if any, its default."""
            type_name = _render_type(f.type)
            entry = {"type": type_name}

            if f.default is not dataclasses.MISSING:
                value = f.default
            elif f.default_factory is not dataclasses.MISSING:
                value = f.default_factory()
            else:
                value = dataclasses.MISSING

            # A missing default is not None either, so Optional fields
            # without any default are rejected here as well.
            if type_name.startswith("Optional[") and value is not None:
                raise AssertionError(
                    f"Optional field {ty.__name__}.{f.name} must have default value to be None."
                )

            if value is not dataclasses.MISSING:
                entry["default"] = str(value)
            return entry

        return {f.name: _render_field(f) for f in dataclasses.fields(ty)}

    # Pass 1: gather every public name that belongs to the schema module.
    # This must finish before any type is rendered, because string
    # annotations are checked against ``defs``.
    module_name = schema.__name__
    for attr in dir(schema):
        if attr.startswith("_"):
            continue
        obj = getattr(schema, attr)
        defined_here = not hasattr(obj, "__module__") or obj.__module__ == module_name
        if defined_here:
            defs[attr] = obj

    # Pass 2: classify each collected definition.
    for name, value in defs.items():
        if not isinstance(value, type):
            if not isinstance(value, (int, tuple)):
                raise AssertionError(f"Unknown variable {name}: {value}")
            assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION")
            continue

        if issubclass(value, IntEnum):
            staged[name] = {
                "kind": "enum",
                "fields": {member.name: member.value for member in value},
            }
            continue

        if not dataclasses.is_dataclass(value):
            raise AssertionError(f"Unknown schema type {name}: {value}")

        kind = "union" if issubclass(value, _Union) else "struct"
        staged[name] = {"kind": kind, "fields": _describe_fields(value)}

    schema_version = list(defs["SCHEMA_VERSION"])
    staged["SCHEMA_VERSION"] = schema_version
    assert all(part > 0 for part in schema_version)

    treespec_version = defs["TREESPEC_VERSION"]
    staged["TREESPEC_VERSION"] = treespec_version
    assert treespec_version > 0
    return staged
| |
| |
| def _diff_schema(dst, src): |
| additions = {key: src[key] for key in src.keys() - dst.keys()} |
| subtractions = {key: dst[key] for key in dst.keys() - src.keys()} |
| |
| common_keys = src.keys() & dst.keys() |
| |
| versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} |
| common_keys -= versions |
| |
| for key in common_keys: |
| src_kind = src[key]["kind"] |
| src_fields = src[key]["fields"] |
| dst_kind = dst[key]["kind"] |
| dst_fields = dst[key]["fields"] |
| _check( |
| src_kind == dst_kind, |
| f"Type {key} changed kind from {dst_kind} to {src_kind}", |
| ) |
| assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) |
| added_fields = { |
| key: src_fields[key] for key in src_fields.keys() - dst_fields.keys() |
| } |
| subtracted_fields = { |
| key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() |
| } |
| common_fields = src_fields.keys() & dst_fields.keys() |
| |
| for field in common_fields: |
| src_field = src_fields[field] |
| dst_field = dst_fields[field] |
| if src_kind == "struct": |
| _check( |
| src_field["type"] == dst_field["type"], |
| f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", |
| ) |
| if "default" in src_field and "default" not in dst_field: |
| added_fields[field] = {} |
| added_fields[field]["default"] = src_field["default"] |
| if "default" not in src_field and "default" in dst_field: |
| subtracted_fields[field] = {} |
| subtracted_fields[field]["default"] = dst_field["default"] |
| elif src_kind == "enum": |
| _check( |
| src_field == dst_field, |
| f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", |
| ) |
| elif src_kind == "union": |
| _check( |
| src_field["type"] == dst_field["type"], |
| f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", |
| ) |
| else: |
| raise AssertionError(f"Unknown kind {src_kind}: {key}") |
| if len(added_fields) > 0: |
| assert key not in additions |
| additions[key] = {} |
| additions[key]["fields"] = added_fields |
| if len(subtracted_fields) > 0: |
| assert key not in subtractions |
| subtractions[key] = {} |
| subtractions[key]["fields"] = subtracted_fields |
| |
| return additions, subtractions |
| |
| |
| def _hash_schema(s): |
| return hashlib.sha256(repr(s).encode("utf-8")).hexdigest() |
| |
| |
| @dataclasses.dataclass |
| class _Commit: |
| result: Dict[str, Any] |
| checksum_result: str |
| path: str |
| additions: Dict[str, Any] |
| subtractions: Dict[str, Any] |
| base: Dict[str, Any] |
| checksum_base: Optional[str] |
| |
| |
def update_schema():
    """Compare the staged schema with the packaged ``schema.yaml``.

    If ``schema.yaml`` exists as a resource of this package, it is loaded as
    the base schema and its embedded ``checksum<<...>>`` marker is extracted.
    Otherwise the base is a placeholder holding only ``None`` version entries.

    Returns:
        A ``_Commit`` with the staged schema, its hash, the yaml path, the
        additions and subtractions relative to the base, the base itself, and
        the base checksum (``None`` when no ``schema.yaml`` exists).

    Raises:
        SchemaUpdateError: If ``schema.yaml`` exists but has no checksum
            marker, or if the staged schema makes a change that
            ``_diff_schema`` rejects.
    """
    import importlib.resources

    yaml_name = "schema.yaml"

    if not importlib.resources.is_resource(__package__, yaml_name):
        checksum_base = None
        dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
    else:
        content = importlib.resources.read_text(__package__, yaml_name)
        match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
        _check(match is not None, "checksum not found in schema.yaml")
        # Repeated as a plain assert after the _check call above.
        assert match is not None
        checksum_base = match.group(1)

        from yaml import load, Loader

        # NOTE(review): this is the full yaml Loader; the content comes from
        # this package's own resource -- confirm it is never user-supplied.
        dst = load(content, Loader=Loader)
        assert isinstance(dst, dict)

    src = _staged_schema()
    additions, subtractions = _diff_schema(dst, src)
    checksum_result = _hash_schema(src)
    yaml_path = __package__.replace(".", "/") + "/schema.yaml"

    return _Commit(
        result=src,
        checksum_result=checksum_result,
        path=yaml_path,
        additions=additions,
        subtractions=subtractions,
        base=dst,
        checksum_base=checksum_base,
    )
| |
| |
def check(commit: "_Commit", force_unsafe: bool = False):
    """Work out which schema version a staged change requires, and why.

    Args:
        commit: Result of ``update_schema()``. Only its ``additions``,
            ``subtractions``, ``base`` and ``result`` attributes are read.
        force_unsafe: When true, the version declared in
            ``commit.result["SCHEMA_VERSION"]`` is returned regardless of
            what was detected, and minor-update detection is skipped.

    Returns:
        ``(next_version, reason)``. ``next_version`` is ``[major, minor]``,
        or ``None`` when no change needs a version bump. ``reason`` holds one
        line per change that led to the decision.
    """
    next_version = None
    reason = ""

    # Step 1: Detect major schema updates.
    for k, v in commit.additions.items():
        if k not in commit.base:
            # A brand-new type is handled as a minor update below.
            continue
        kind = commit.result[k]["kind"]
        for f, d in v["fields"].items():
            if "default" not in d and kind == "struct":
                reason += (
                    f"Field {k}.{f} is added to schema.py without a default value as an incomparible change "
                    + "which requires major version bump.\n"
                )
                next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    for k, v in commit.subtractions.items():
        if k not in commit.result:
            # A removed type is handled as a minor update below.
            continue
        for f in v["fields"]:
            # Accumulate rather than assign, so reasons collected for added
            # fields and for earlier removed fields are not lost.
            reason += f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n"
            next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    if force_unsafe:
        reason += "--force-unsafe is used."
        next_version = commit.result["SCHEMA_VERSION"]
    else:
        # Step 2: Detect minor schema updates.
        if next_version is None and commit.additions:
            for k, v in commit.additions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is added to schema.py as an compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]
        if next_version is None and commit.subtractions:
            for k, v in commit.subtractions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is removed from schema.py as an compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]

    return next_version, reason