# blob: 0c0120397eea3b5ce0c48c81fabfa9973cdfb7ee [file] [log] [blame]  (code-viewer header, not part of the module)
# Owner(s): ["module: pytree"]
import unittest
from collections import namedtuple, OrderedDict
import torch
import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._pytree as py_pytree
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
# Module-level namedtuple; TestCxxPytree.test_pytree_serialize_namedtuple
# round-trips a spec built from it alongside one built from a locally defined type.
GlobalPoint = namedtuple("GlobalPoint", ("x", "y"))
class GlobalDummyType:
    """Plain two-attribute container defined at module scope.

    TestCxxPytree.test_pytree_custom_type_serialize registers it as a custom
    pytree node whose children are ``[x, y]``.
    """

    def __init__(self, x, y):
        # Store both constructor arguments unchanged.
        self.x, self.y = x, y
class TestGenericPytree(TestCase):
    """Pytree behavior checks shared by the Python and C++ implementations.

    Every test receives ``pytree_impl``, which is either ``py_pytree``
    (``torch.utils._pytree``) or ``cxx_pytree`` (``torch.utils._cxx_pytree``),
    and exercises the same entry points on whichever module it was given.
    """

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_leaf(self, pytree_impl):
        # A leaf must flatten to ``[leaf]`` with a ``LeafSpec`` and round-trip.
        def run_test_with_leaf(leaf):
            values, treespec = pytree_impl.tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, pytree_impl.LeafSpec())

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        for leaf in (1, 1.0, None, bool, torch.randn(3, 3)):
            run_test_with_leaf(leaf)

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda lst: py_pytree.TreeSpec(
                        list, None, [py_pytree.LeafSpec() for _ in lst]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
        # ``gen_expected_fn`` builds the spec this implementation should report
        # for a flat list of the same length.
        def run_test(lst):
            expected_spec = gen_expected_fn(lst)
            values, treespec = pytree_impl.tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda tup: py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec() for _ in tup]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
        # Flattening always yields a list of leaves; unflattening must give
        # back a tuple.
        def run_test(tup):
            expected_spec = gen_expected_fn(tup)
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda dct: py_pytree.TreeSpec(
                        dict,
                        list(dct.keys()),
                        [py_pytree.LeafSpec() for _ in dct.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
        # The flattened leaves are the dict's values, in the dict's own order.
        def run_test(dct):
            expected_spec = gen_expected_fn(dct)
            values, treespec = pytree_impl.tree_flatten(dct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(dct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, dct)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda odict: py_pytree.TreeSpec(
                        OrderedDict,
                        list(odict.keys()),
                        [py_pytree.LeafSpec() for _ in odict.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda odict: cxx_pytree.tree_structure(
                        OrderedDict.fromkeys(odict, 0)
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_odict(self, pytree_impl, gen_expected_fn):
        # Same contract as the dict test, but the container type must survive
        # the round trip as an ``OrderedDict``.
        def run_test(odict):
            expected_spec = gen_expected_fn(odict)
            values, treespec = pytree_impl.tree_flatten(odict)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(odict.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, odict)
            self.assertIsInstance(unflattened, OrderedDict)

        od = OrderedDict()
        run_test(od)

        # Insert "b" before "a" so insertion order differs from sorted order.
        od["b"] = 1
        od["a"] = torch.tensor(3.14)
        run_test(od)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_namedtuple(self, pytree_impl):
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            # The two implementations spell the expected spec differently.
            if pytree_impl is py_pytree:
                expected_spec = py_pytree.TreeSpec(
                    namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
                )
            else:
                expected_spec = cxx_pytree.tree_structure(Point(0, 1))
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, Point)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    @parametrize(
        "op",
        [
            subtest(torch.max, name="max"),
            subtest(torch.min, name="min"),
        ],
    )
    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_return_type(self, pytree_impl, op):
        x = torch.randn(3, 3)
        expected = op(x, dim=0)

        values, spec = pytree_impl.tree_flatten(expected)
        # Check that values is actually List[Tensor] and not (ReturnType(...),)
        for value in values:
            self.assertIsInstance(value, torch.Tensor)
        result = pytree_impl.tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_nested(self, pytree_impl):
        def run_test(pytree):
            values, treespec = pytree_impl.tree_flatten(pytree)
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: python basic data structures (dict list tuple) all have
            # contents equality defined on them, so the following works for them.
            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_treemap(self, pytree_impl):
        def run_test(pytree):
            def f(x):
                return x * 3

            # Mapping over the leaves directly and mapping over the tree must
            # produce the same leaf sum.
            sm1 = sum(map(f, pytree_impl.tree_leaves(pytree)))
            sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree)))
            self.assertEqual(sm1, sm2)

            def invf(x):
                return x // 3

            # ``invf`` undoes ``f`` on the integer leaves used below, so the
            # composition must reproduce the original tree.
            self.assertEqual(
                pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)),
                pytree,
            )

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_only(self, pytree_impl):
        # Only the ``int`` leaf is transformed; the ``str`` leaf is untouched.
        self.assertEqual(
            pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
        )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_tree_all_any(self, pytree_impl):
        self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3]))
        self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1]))
        self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1]))
        self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2]))
        # The ``*_only`` variants are given a ``str`` leaf on which ``x % 2``
        # would not yield a parity; the expected results depend on ints alone.
        self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
        self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
        self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
        self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_broadcast_to_and_flatten(self, pytree_impl):
        # Each case is (pytree, to_pytree, expected): ``pytree`` is broadcast
        # against the spec of ``to_pytree``; ``expected`` is ``None`` for the
        # mismatched-structure cases.
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = pytree_impl.tree_flatten(to_pytree)
            result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_pytree_serialize_bad_input(self, pytree_impl):
        # A plain string is not a tree spec and must be rejected.
        with self.assertRaises(TypeError):
            pytree_impl.treespec_dumps("random_blurb")
class TestPythonPytree(TestCase):
    """Tests specific to the pure-Python pytree (``torch.utils._pytree``).

    Covers ``TreeSpec`` equality and repr, and the ``treespec_dumps`` /
    ``treespec_loads`` serialization entry points, including custom node types
    registered through ``_register_pytree_node``.
    """

    def test_treespec_equality(self):
        # ``==`` / ``!=`` are used directly (not assertEqual) so the operators
        # themselves are what is being exercised.
        self.assertTrue(
            py_pytree.LeafSpec() == py_pytree.LeafSpec(),
        )
        self.assertTrue(
            py_pytree.TreeSpec(list, None, []) == py_pytree.TreeSpec(list, None, []),
        )
        self.assertTrue(
            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()])
            == py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
        )
        self.assertFalse(
            py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
        )
        self.assertTrue(
            py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
        )

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = py_pytree.tree_flatten(pytree)
        # NOTE(review): continuation lines use 2 spaces per nesting level; the
        # two repr tests previously expected different indentation for the same
        # repr(spec) -- confirm against TreeSpec.__repr__.
        self.assertEqual(
            repr(spec),
            (
                "TreeSpec(tuple, None, [*,\n"
                "  TreeSpec(list, None, [*,\n"
                "    *,\n"
                "    TreeSpec(list, None, [*])])])"
            ),
        )

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = py_pytree.tree_flatten(pytree)
        # NOTE(review): expected text kept identical to test_treespec_repr
        # (2 spaces per nesting level) -- confirm against TreeSpec.__repr__.
        self.assertExpectedInline(
            repr(spec),
            """\
TreeSpec(tuple, None, [*,
  TreeSpec(list, None, [*,
    *,
    TreeSpec(list, None, [*])])])""",
        )

    @parametrize(
        "spec",
        [
            py_pytree.TreeSpec(list, None, []),
            py_pytree.TreeSpec(tuple, None, []),
            py_pytree.TreeSpec(dict, [], []),
            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
            py_pytree.TreeSpec(
                list, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
            ),
            py_pytree.TreeSpec(
                tuple,
                None,
                [py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
            ),
            py_pytree.TreeSpec(
                dict,
                ["a", "b", "c"],
                [py_pytree.LeafSpec(), py_pytree.LeafSpec(), py_pytree.LeafSpec()],
            ),
            py_pytree.TreeSpec(
                OrderedDict,
                ["a", "b", "c"],
                [
                    py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
                    ),
                    py_pytree.LeafSpec(),
                    py_pytree.TreeSpec(
                        dict,
                        ["a", "b", "c"],
                        [
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                        ],
                    ),
                ],
            ),
            py_pytree.TreeSpec(
                list,
                None,
                [
                    py_pytree.TreeSpec(
                        tuple,
                        None,
                        [
                            py_pytree.LeafSpec(),
                            py_pytree.LeafSpec(),
                            py_pytree.TreeSpec(
                                list,
                                None,
                                [
                                    py_pytree.LeafSpec(),
                                    py_pytree.LeafSpec(),
                                ],
                            ),
                        ],
                    ),
                ],
            ),
        ],
    )
    def test_pytree_serialize(self, spec):
        # Dumping yields a ``str``; loading it back must compare equal via ``==``.
        serialized_spec = py_pytree.treespec_dumps(spec)
        self.assertIsInstance(serialized_spec, str)
        self.assertTrue(spec == py_pytree.treespec_loads(serialized_spec))

    def test_pytree_serialize_namedtuple(self):
        Point = namedtuple("Point", ["x", "y"])
        spec = py_pytree.TreeSpec(
            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
        # The context in the namedtuple is different now because we recreated
        # the namedtuple type.
        self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)

    # NOTE(review): marked as an expected failure, i.e. the assertRaisesRegex
    # below is currently expected NOT to hold -- confirm this is still intended.
    @unittest.expectedFailure
    def test_pytree_custom_type_serialize_bad(self):
        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        # Registered without a serialized type name or context (de)serializers.
        py_pytree._register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
        )

        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )
        with self.assertRaisesRegex(
            NotImplementedError, "No registered serialization name"
        ):
            py_pytree.treespec_dumps(spec)

    def test_pytree_custom_type_serialize(self):
        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        py_pytree._register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
            serialized_type_name="test_pytree_custom_type_serialize.DummyType",
            to_dumpable_context=lambda context: "moo",
            from_dumpable_context=lambda dumpable_context: None,
        )
        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )
        serialized_spec = py_pytree.treespec_dumps(spec, 1)
        # The value produced by ``to_dumpable_context`` must appear in the output.
        self.assertIn("moo", serialized_spec)
        roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)

    def test_pytree_serialize_register_bad(self):
        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        # Only ``to_dumpable_context`` is supplied; registration must reject
        # the missing ``from_dumpable_context``.
        with self.assertRaisesRegex(
            ValueError, "Both to_dumpable_context and from_dumpable_context"
        ):
            py_pytree._register_pytree_node(
                DummyType,
                lambda dummy: ([dummy.x, dummy.y], None),
                lambda xs, _: DummyType(*xs),
                serialized_type_name="test_pytree_serialize_register_bad.DummyType",
                to_dumpable_context=lambda context: "moo",
            )

    def test_pytree_context_serialize_bad(self):
        class DummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        # ``to_dumpable_context`` returns a class object, which cannot be
        # JSON-encoded, so dumping is expected to raise.
        py_pytree._register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
            serialized_type_name="test_pytree_serialize_serialize_bad.DummyType",
            to_dumpable_context=lambda context: DummyType,
            from_dumpable_context=lambda dumpable_context: None,
        )

        spec = py_pytree.TreeSpec(
            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        with self.assertRaisesRegex(
            TypeError, "Object of type type is not JSON serializable"
        ):
            py_pytree.treespec_dumps(spec)

    def test_pytree_serialize_bad_protocol(self):
        import json

        Point = namedtuple("Point", ["x", "y"])
        spec = py_pytree.TreeSpec(
            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )

        # An unknown protocol number is rejected when dumping ...
        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
            py_pytree.treespec_dumps(spec, -1)

        # ... and when loading a payload whose protocol field was tampered with.
        serialized_spec = py_pytree.treespec_dumps(spec)
        _, data = json.loads(serialized_spec)
        bad_protocol_serialized_spec = json.dumps((-1, data))

        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
            py_pytree.treespec_loads(bad_protocol_serialized_spec)

    def test_saved_serialized(self):
        # Pins the serialized text for a fixed spec: dumping must reproduce
        # ``saved_spec`` exactly and loading ``saved_spec`` must give the spec back.
        complicated_spec = py_pytree.TreeSpec(
            OrderedDict,
            [1, 2, 3],
            [
                py_pytree.TreeSpec(
                    tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
                ),
                py_pytree.LeafSpec(),
                py_pytree.TreeSpec(
                    dict,
                    [4, 5, 6],
                    [
                        py_pytree.LeafSpec(),
                        py_pytree.LeafSpec(),
                        py_pytree.LeafSpec(),
                    ],
                ),
            ],
        )

        serialized_spec = py_pytree.treespec_dumps(complicated_spec)
        saved_spec = (
            '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
            '"children_spec": [{"type": "builtins.tuple", "context": "null", '
            '"children_spec": [{"type": null, "context": null, '
            '"children_spec": []}, {"type": null, "context": null, '
            '"children_spec": []}]}, {"type": null, "context": null, '
            '"children_spec": []}, {"type": "builtins.dict", "context": '
            '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
            '"children_spec": []}, {"type": null, "context": null, "children_spec": '
            '[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
        )
        self.assertEqual(serialized_spec, saved_spec)
        self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
class TestCxxPytree(TestCase):
    """Tests specific to the C++ pytree (``torch.utils._cxx_pytree``).

    Covers spec equality and repr, and ``treespec_dumps`` / ``treespec_loads``
    round trips for built-in containers, namedtuples and registered custom types.
    """

    def test_treespec_equality(self):
        # ``==`` is used directly so the operator itself is exercised.
        self.assertTrue(cxx_pytree.LeafSpec() == cxx_pytree.LeafSpec())

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertEqual(
            repr(spec),
            ("PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)"),
        )

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertExpectedInline(
            repr(spec),
            "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
        )

    @parametrize(
        "spec",
        [
            cxx_pytree.tree_structure([]),
            cxx_pytree.tree_structure(()),
            cxx_pytree.tree_structure({}),
            cxx_pytree.tree_structure([0]),
            cxx_pytree.tree_structure([0, 1]),
            cxx_pytree.tree_structure((0, 1, 2)),
            cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
            cxx_pytree.tree_structure(
                OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
            ),
            cxx_pytree.tree_structure([(0, 1, [2, 3])]),
        ],
    )
    def test_pytree_serialize(self, spec):
        # Dumping yields a ``str``; loading it back must compare equal via ``==``.
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        self.assertIsInstance(serialized_spec, str)
        self.assertTrue(spec == cxx_pytree.treespec_loads(serialized_spec))

    def test_pytree_serialize_namedtuple(self):
        # Round-trip a spec for a module-level namedtuple and for one defined
        # inside the test; only the field names are compared afterwards.
        LocalPoint = namedtuple("LocalPoint", ["x", "y"])
        for point_cls in (GlobalPoint, LocalPoint):
            spec = cxx_pytree.tree_structure(point_cls(0, 1))
            roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
            self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

    def test_pytree_custom_type_serialize(self):
        # First a custom node type defined at module scope ...
        cxx_pytree.register_pytree_node(
            GlobalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: GlobalDummyType(*xs),
            serialized_type_name="GlobalDummyType",
        )
        spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)

        # ... then one defined locally inside the test.
        class LocalDummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        cxx_pytree.register_pytree_node(
            LocalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: LocalDummyType(*xs),
            serialized_type_name="LocalDummyType",
        )
        spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)
# Apply instantiate_parametrized_tests to each test class that uses @parametrize.
for _suite in (TestGenericPytree, TestPythonPytree, TestCxxPytree):
    instantiate_parametrized_tests(_suite)
# Drop the loop alias so no extra module-level reference to a test class remains.
del _suite

if __name__ == "__main__":
    run_tests()