| # Owner(s): ["oncall: export"] |
| from torch._export.serde.schema_check import ( |
| _Commit, |
| _diff_schema, |
| check, |
| SchemaUpdateError, |
| update_schema, |
| ) |
| |
| from torch.testing._internal.common_utils import IS_FBCODE, run_tests, TestCase |
| |
| |
| class TestSchema(TestCase): |
| def test_schema_compatibility(self): |
| msg = """ |
| Detected an invalidated change to export schema. Please run the following script to update the schema: |
| Example(s): |
| python scripts/export/update_schema.py --prefix <path_to_torch_development_diretory> |
| """ |
| |
| if IS_FBCODE: |
| msg += """or |
| buck run caffe2:export_update_schema -- --prefix /data/users/$USER/fbsource/fbcode/caffe2/ |
| """ |
| try: |
| commit = update_schema() |
| except SchemaUpdateError as e: |
| self.fail(f"Failed to update schema: {e}\n{msg}") |
| |
| self.assertEqual(commit.checksum_base, commit.checksum_result, msg) |
| |
| def test_schema_diff(self): |
| additions, subtractions = _diff_schema( |
| { |
| "Type0": {"kind": "struct", "fields": {}}, |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| "field2": {"type": ""}, |
| "field3": {"type": "", "default": "[]"}, |
| }, |
| }, |
| }, |
| { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field1": {"type": "", "default": "0"}, |
| "field2": {"type": "", "default": "[]"}, |
| "field3": {"type": ""}, |
| }, |
| }, |
| "Type1": {"kind": "struct", "fields": {}}, |
| }, |
| ) |
| |
| self.assertEqual( |
| additions, |
| { |
| "Type1": {"kind": "struct", "fields": {}}, |
| "Type2": { |
| "fields": { |
| "field1": {"type": "", "default": "0"}, |
| "field2": {"default": "[]"}, |
| }, |
| }, |
| }, |
| ) |
| self.assertEqual( |
| subtractions, |
| { |
| "Type0": {"kind": "struct", "fields": {}}, |
| "Type2": { |
| "fields": { |
| "field0": {"type": ""}, |
| "field3": {"default": "[]"}, |
| }, |
| }, |
| }, |
| ) |
| |
| def test_schema_check(self): |
| # Adding field without default value |
| dst = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| "field1": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [4, 1]) |
| |
| # Removing field |
| dst = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "struct", |
| "fields": {}, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [4, 1]) |
| |
| # Adding field with default value |
| dst = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| "field1": {"type": "", "default": "[]"}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [3, 3]) |
| |
| # Changing field type |
| dst = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": "int"}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| with self.assertRaises(SchemaUpdateError): |
| _diff_schema(dst, src) |
| |
| # Adding new type. |
| dst = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "Type1": {"kind": "struct", "fields": {}}, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [3, 3]) |
| |
| # Removing a type. |
| dst = { |
| "Type2": { |
| "kind": "struct", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [3, 3]) |
| |
| # Adding new field in union. |
| dst = { |
| "Type2": { |
| "kind": "union", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "union", |
| "fields": { |
| "field0": {"type": ""}, |
| "field1": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [3, 3]) |
| |
| # Removing a field in union. |
| dst = { |
| "Type2": { |
| "kind": "union", |
| "fields": { |
| "field0": {"type": ""}, |
| }, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| src = { |
| "Type2": { |
| "kind": "union", |
| "fields": {}, |
| }, |
| "SCHEMA_VERSION": [3, 2], |
| } |
| |
| additions, subtractions = _diff_schema(dst, src) |
| |
| commit = _Commit( |
| result=src, |
| checksum_result="", |
| path="", |
| additions=additions, |
| subtractions=subtractions, |
| base=dst, |
| checksum_base="", |
| ) |
| next_version, _ = check(commit) |
| self.assertEqual(next_version, [4, 1]) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |