# blob: cd0357dac32fcefb2d5aa77181cb399dbcbcc7f1 (code-viewer header residue)
# Owner(s): ["module: pytree"]
import collections
import inspect
import os
import re
import subprocess
import sys
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from typing import Any, NamedTuple
import torch
import torch.utils._pytree as py_pytree
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_FBCODE,
parametrize,
run_tests,
skipIfTorchDynamo,
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
# Select the module exercised by the "cxx" parametrizations below.
if not IS_FBCODE:
    import torch.utils._cxx_pytree as cxx_pytree
else:
    # optree is not yet enabled in fbcode, so just re-test the python implementation
    cxx_pytree = py_pytree

# Module-level namedtuple type (fields ``x`` and ``y``).
GlobalPoint = namedtuple("GlobalPoint", ("x", "y"))
class GlobalDummyType:
    """Plain two-attribute container defined at module scope.

    Stores the constructor arguments unchanged as ``x`` and ``y``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
class TestGenericPytree(TestCase):
def test_aligned_public_apis(self):
    """Verify the Python and C++ pytree modules expose the same public API.

    Both modules must list identical ``__all__`` entries and agree on
    whether each entry is a class or a function. For functions, parameter
    names, kinds, defaults and annotations must match; annotations that
    mention ``TreeSpec`` are compared after stripping any dotted prefix.
    """
    treespec_pattern = r"(?:\b)([\w\.]*)TreeSpec(?:\b)"
    positional_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    }

    def positional_names(signature):
        # Names of the parameters that may be passed positionally.
        return [
            param_name
            for param_name, param in signature.parameters.items()
            if param.kind in positional_kinds
        ]

    def strip_treespec_prefix(annotation):
        # Collapse e.g. ``some.module.TreeSpec`` down to ``TreeSpec``.
        return re.sub(treespec_pattern, "TreeSpec", str(annotation))

    api_names = py_pytree.__all__
    self.assertEqual(api_names, cxx_pytree.__all__)

    for name in api_names:
        cxx_api = getattr(cxx_pytree, name)
        py_api = getattr(py_pytree, name)

        self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api))
        self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api))

        # Signatures are only compared for plain functions.
        if not inspect.isfunction(cxx_api):
            continue

        cxx_signature = inspect.signature(cxx_api)
        py_signature = inspect.signature(py_api)
        cxx_params = cxx_signature.parameters
        py_params = py_signature.parameters

        # Check the parameter names are the same.
        self.assertEqual(list(cxx_params), list(py_params))

        # Check the positional parameters are the same.
        self.assertEqual(
            positional_names(cxx_signature), positional_names(py_signature)
        )

        for py_name, py_param in py_params.items():
            self.assertIn(py_name, cxx_params)
            cxx_param = cxx_params[py_name]

            # Check parameter kinds and default values are the same.
            self.assertEqual(cxx_param.kind, py_param.kind)
            self.assertEqual(cxx_param.default, py_param.default)

            mismatch_msg = (
                f"C++ parameter {cxx_param} "
                f"does not match Python parameter {py_param} "
                f"for API `{name}`"
            )

            # Check parameter annotations are the same.
            if "TreeSpec" not in str(cxx_param.annotation):
                self.assertEqual(
                    cxx_param.annotation, py_param.annotation, msg=mismatch_msg
                )
                continue

            self.assertIn("TreeSpec", str(py_param.annotation))
            self.assertEqual(
                strip_treespec_prefix(cxx_param.annotation),
                strip_treespec_prefix(py_param.annotation),
                msg=mismatch_msg,
            )
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_register_pytree_node(self, pytree_impl):
    """Register a ``UserDict`` subclass as a pytree node.

    Before registration the instance flattens to itself as a single leaf;
    afterwards its values are the leaves. Registering the same type a
    second time must raise ``ValueError``.
    """

    class MyDict(UserDict):
        pass

    def flatten_my_dict(d):
        # Values are the children; keys are carried as the context.
        return list(d.values()), list(d.keys())

    def unflatten_my_dict(values, keys):
        return MyDict(zip(keys, values))

    original = MyDict(a=1, b=2, c=3)

    # Custom types are leaf nodes by default
    leaves, treespec = pytree_impl.tree_flatten(original)
    self.assertEqual(leaves, [original])
    self.assertIs(leaves[0], original)
    self.assertEqual(original, pytree_impl.tree_unflatten(leaves, treespec))
    self.assertTrue(treespec.is_leaf())

    # Register MyDict as a pytree node
    pytree_impl.register_pytree_node(MyDict, flatten_my_dict, unflatten_my_dict)

    leaves, treespec = pytree_impl.tree_flatten(original)
    self.assertEqual(leaves, [1, 2, 3])
    self.assertEqual(original, pytree_impl.tree_unflatten(leaves, treespec))

    # Do not allow registering the same type twice
    with self.assertRaisesRegex(ValueError, "already registered"):
        pytree_impl.register_pytree_node(
            MyDict, flatten_my_dict, unflatten_my_dict
        )
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_flatten_unflatten_leaf(self, pytree_impl):
    """Flatten and unflatten bare leaves of several types.

    Each leaf must flatten to a one-element list with a ``LeafSpec`` and
    unflatten back to an equal value.
    """

    def check_leaf(leaf):
        leaves, spec = pytree_impl.tree_flatten(leaf)
        self.assertEqual(leaves, [leaf])
        self.assertEqual(spec, pytree_impl.LeafSpec())
        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, leaf)

    for leaf in (1, 1.0, None, bool, torch.randn(3, 3)):
        check_leaf(leaf)
@parametrize(
    "pytree_impl,gen_expected_fn",
    [
        subtest(
            (
                py_pytree,
                lambda tup: py_pytree.TreeSpec(
                    tuple, None, [py_pytree.LeafSpec() for _ in tup]
                ),
            ),
            name="py",
        ),
        subtest(
            (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
            name="cxx",
        ),
    ],
)
def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
    """Round-trip flat tuples and compare the spec with ``gen_expected_fn``."""

    def check_roundtrip(tup):
        expected = gen_expected_fn(tup)
        leaves, spec = pytree_impl.tree_flatten(tup)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, list(tup))
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, tup)
        self.assertIsInstance(rebuilt, tuple)

    samples = [
        (),
        (1.0,),
        (1.0, 2),
        (torch.tensor([1.0, 2]), 2, 10, 9, 11),
    ]
    for sample in samples:
        check_roundtrip(sample)
@parametrize(
    "pytree_impl,gen_expected_fn",
    [
        subtest(
            (
                py_pytree,
                lambda lst: py_pytree.TreeSpec(
                    list, None, [py_pytree.LeafSpec() for _ in lst]
                ),
            ),
            name="py",
        ),
        subtest(
            (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
            name="cxx",
        ),
    ],
)
def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
    """Round-trip flat lists and compare the spec with ``gen_expected_fn``."""

    def check_roundtrip(lst):
        expected = gen_expected_fn(lst)
        leaves, spec = pytree_impl.tree_flatten(lst)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, lst)
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, lst)
        self.assertIsInstance(rebuilt, list)

    samples = [
        [],
        [1.0, 2],
        [torch.tensor([1.0, 2]), 2, 10, 9, 11],
    ]
    for sample in samples:
        check_roundtrip(sample)
@parametrize(
    "pytree_impl,gen_expected_fn",
    [
        subtest(
            (
                py_pytree,
                lambda dct: py_pytree.TreeSpec(
                    dict,
                    list(dct.keys()),
                    [py_pytree.LeafSpec() for _ in dct.values()],
                ),
            ),
            name="py",
        ),
        subtest(
            (
                cxx_pytree,
                lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
            ),
            name="cxx",
        ),
    ],
)
def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
    """Round-trip flat dicts and compare the spec with ``gen_expected_fn``."""

    def check_roundtrip(dct):
        expected = gen_expected_fn(dct)
        leaves, spec = pytree_impl.tree_flatten(dct)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, list(dct.values()))
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, dct)
        self.assertIsInstance(rebuilt, dict)

    samples = [
        {},
        {"a": 1},
        {"abcdefg": torch.randn(2, 3)},
        {1: torch.randn(2, 3)},
        {"a": 1, "b": 2, "c": torch.randn(2, 3)},
    ]
    for sample in samples:
        check_roundtrip(sample)
@parametrize(
    "pytree_impl,gen_expected_fn",
    [
        subtest(
            (
                py_pytree,
                lambda odict: py_pytree.TreeSpec(
                    OrderedDict,
                    list(odict.keys()),
                    [py_pytree.LeafSpec() for _ in odict.values()],
                ),
            ),
            name="py",
        ),
        subtest(
            (
                cxx_pytree,
                lambda odict: cxx_pytree.tree_structure(
                    OrderedDict.fromkeys(odict, 0)
                ),
            ),
            name="cxx",
        ),
    ],
)
def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
    """Round-trip an empty and a two-entry ``OrderedDict``."""

    def check_roundtrip(odict):
        expected = gen_expected_fn(odict)
        leaves, spec = pytree_impl.tree_flatten(odict)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, list(odict.values()))
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, odict)
        self.assertIsInstance(rebuilt, OrderedDict)

    check_roundtrip(OrderedDict())
    # Keys are inserted as "b" then "a".
    check_roundtrip(OrderedDict([("b", 1), ("a", torch.tensor(3.14))]))
@parametrize(
    "pytree_impl,gen_expected_fn",
    [
        subtest(
            (
                py_pytree,
                lambda ddct: py_pytree.TreeSpec(
                    defaultdict,
                    [ddct.default_factory, list(ddct.keys())],
                    [py_pytree.LeafSpec() for _ in ddct.values()],
                ),
            ),
            name="py",
        ),
        subtest(
            (
                cxx_pytree,
                lambda ddct: cxx_pytree.tree_structure(
                    defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
                ),
            ),
            name="cxx",
        ),
    ],
)
def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
    """Round-trip ``defaultdict`` instances, including their default factory."""

    def check_roundtrip(ddct):
        expected = gen_expected_fn(ddct)
        leaves, spec = pytree_impl.tree_flatten(ddct)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, list(ddct.values()))
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, ddct)
        self.assertEqual(rebuilt.default_factory, ddct.default_factory)
        self.assertIsInstance(rebuilt, defaultdict)

    samples = [
        defaultdict(list, {}),
        defaultdict(int, {"a": 1}),
        defaultdict(int, {"abcdefg": torch.randn(2, 3)}),
        defaultdict(int, {1: torch.randn(2, 3)}),
        defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}),
    ]
    for sample in samples:
        check_roundtrip(sample)
@parametrize(
    "pytree_impl,gen_expected_fn",
    [
        subtest(
            (
                py_pytree,
                lambda deq: py_pytree.TreeSpec(
                    deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
                ),
            ),
            name="py",
        ),
        subtest(
            (
                cxx_pytree,
                lambda deq: cxx_pytree.tree_structure(
                    deque(deq, maxlen=deq.maxlen)
                ),
            ),
            name="cxx",
        ),
    ],
)
def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn):
    """Round-trip ``deque`` instances, including their ``maxlen``."""

    def check_roundtrip(deq):
        expected = gen_expected_fn(deq)
        leaves, spec = pytree_impl.tree_flatten(deq)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, list(deq))
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, deq)
        self.assertEqual(rebuilt.maxlen, deq.maxlen)
        self.assertIsInstance(rebuilt, deque)

    samples = [
        deque([]),
        deque([1.0, 2]),
        deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8),
    ]
    for sample in samples:
        check_roundtrip(sample)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_flatten_unflatten_namedtuple(self, pytree_impl):
    """Round-trip instances of a locally defined two-field namedtuple."""
    Point = namedtuple("Point", ["x", "y"])

    def expected_spec_for(tup):
        # The Python implementation's spec is built by hand; otherwise the
        # expected spec is derived from a placeholder ``Point``.
        if pytree_impl is py_pytree:
            return py_pytree.TreeSpec(
                namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
            )
        return cxx_pytree.tree_structure(Point(0, 1))

    def check_roundtrip(tup):
        expected = expected_spec_for(tup)
        leaves, spec = pytree_impl.tree_flatten(tup)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaves, list(tup))
        self.assertEqual(spec, expected)

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, tup)
        self.assertIsInstance(rebuilt, Point)

    for sample in (Point(1.0, 2), Point(torch.tensor(1.0), 2)):
        check_roundtrip(sample)
@parametrize(
    "op",
    [subtest(torch.max, name="max"), subtest(torch.min, name="min")],
)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_flatten_unflatten_return_types(self, pytree_impl, op):
    """Round-trip the structured result of ``op(x, dim=0)``."""
    expected = op(torch.randn(3, 3), dim=0)

    leaves, treespec = pytree_impl.tree_flatten(expected)
    # Check that values is actually List[Tensor] and not (ReturnType(...),)
    for leaf in leaves:
        self.assertIsInstance(leaf, torch.Tensor)

    rebuilt = pytree_impl.tree_unflatten(leaves, treespec)
    self.assertEqual(type(rebuilt), type(expected))
    self.assertEqual(rebuilt, expected)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_flatten_unflatten_nested(self, pytree_impl):
    """Round-trip nested list/tuple/dict structures."""

    def check_roundtrip(pytree):
        leaves, spec = pytree_impl.tree_flatten(pytree)
        self.assertIsInstance(leaves, list)
        self.assertEqual(len(leaves), spec.num_leaves)

        # NB: python basic data structures (dict list tuple) all have
        # contents equality defined on them, so the following works for them.
        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, pytree)

    samples = [
        [()],
        ([],),
        {"a": ()},
        {"a": 0, "b": [{"c": 1}]},
        {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
    ]
    for sample in samples:
        check_roundtrip(sample)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_flatten_with_is_leaf(self, pytree_impl):
    """Flatten with an ``is_leaf`` predicate that stops below the root.

    Every object other than the root itself is treated as a leaf, so the
    flattened values must be exactly the root's direct children.
    """

    def check_one_level(pytree, one_level_leaves):
        leaves, spec = pytree_impl.tree_flatten(
            pytree, is_leaf=lambda x: x is not pytree
        )
        leaf_count = len(leaves)
        self.assertIsInstance(leaves, list)
        self.assertEqual(leaf_count, spec.num_nodes - 1)
        self.assertEqual(leaf_count, spec.num_leaves)
        self.assertEqual(leaf_count, spec.num_children)
        self.assertEqual(leaves, one_level_leaves)

        # The spec must equal the structure of a tree rebuilt from it.
        placeholder_tree = pytree_impl.tree_unflatten([0] * spec.num_leaves, spec)
        self.assertEqual(spec, pytree_impl.tree_structure(placeholder_tree))

        rebuilt = pytree_impl.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, pytree)

    samples = [
        ([()], [()]),
        (([],), [[]]),
        ({"a": ()}, [()]),
        ({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]),
        (
            {
                "a": 0,
                "b": [1, {"c": 2}, torch.ones(3)],
                "c": (torch.zeros(2, 3), 1),
            },
            [0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)],
        ),
    ]
    for pytree, one_level_leaves in samples:
        check_one_level(pytree, one_level_leaves)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_tree_map(self, pytree_impl):
    """Map a function over integer pytrees and undo it with its inverse."""

    def triple(x):
        return x * 3

    def third(x):
        return x // 3

    def check(pytree):
        # Mapping the leaves directly and mapping the tree must agree.
        direct_total = sum(map(triple, pytree_impl.tree_leaves(pytree)))
        mapped_total = sum(
            pytree_impl.tree_leaves(pytree_impl.tree_map(triple, pytree))
        )
        self.assertEqual(direct_total, mapped_total)

        tripled = pytree_impl.tree_map(triple, pytree)
        self.assertEqual(pytree_impl.tree_map(third, tripled), pytree)

    samples = [
        [()],
        ([],),
        {"a": ()},
        {"a": 1, "b": [{"c": 2}]},
        {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
    ]
    for sample in samples:
        check(sample)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_tree_map_multi_inputs(self, pytree_impl):
    """Map a three-argument function over three trees at once.

    The second and third trees are derived from the first, and the result
    must equal a single-tree map that rebuilds those arguments per leaf.
    """

    def combine(x, y, z):
        return x, [y, (z, 0)]

    def check(pytree):
        tree_x = pytree
        tree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree)
        tree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree)

        multi_tree_result = pytree_impl.tree_map(combine, tree_x, tree_y, tree_z)
        single_tree_result = pytree_impl.tree_map(
            lambda x: combine(x, (x + 1,), {"a": x * 2, "b": 2}), pytree
        )
        self.assertEqual(multi_tree_result, single_tree_result)

    samples = [
        [()],
        ([],),
        {"a": ()},
        {"a": 1, "b": [{"c": 2}]},
        {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
    ]
    for sample in samples:
        check(sample)
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_tree_map_only(self, pytree_impl):
    """``tree_map_only`` with a type maps ints and leaves the string alone."""
    mapped = pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"])
    self.assertEqual(mapped, [2, "a"])
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_tree_map_only_predicate_fn(self, pytree_impl):
    """``tree_map_only`` with a predicate maps only the leaf equal to zero."""
    mapped = pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1])
    self.assertEqual(mapped, [2, 1])
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_tree_all_any(self, pytree_impl):
    """Check ``tree_all``/``tree_any`` and their type-filtered ``*_only`` forms."""

    def is_odd(x):
        return x % 2

    # Unfiltered variants over all-int trees.
    self.assertTrue(pytree_impl.tree_all(is_odd, [1, 3]))
    self.assertFalse(pytree_impl.tree_all(is_odd, [0, 1]))
    self.assertTrue(pytree_impl.tree_any(is_odd, [0, 1]))
    self.assertFalse(pytree_impl.tree_any(is_odd, [0, 2]))

    # Type-filtered variants over trees that also hold a string leaf.
    self.assertTrue(pytree_impl.tree_all_only(int, is_odd, [1, 3, "a"]))
    self.assertFalse(pytree_impl.tree_all_only(int, is_odd, [0, 1, "a"]))
    self.assertTrue(pytree_impl.tree_any_only(int, is_odd, [0, 1, "a"]))
    self.assertFalse(pytree_impl.tree_any_only(int, is_odd, [0, 2, "a"]))
@parametrize(
    "pytree_impl",
    [
        subtest(py_pytree, name="py"),
        subtest(cxx_pytree, name="cxx"),
    ],
)
def test_broadcast_to_and_flatten(self, pytree_impl):
    """Check ``_broadcast_to_and_flatten`` against a table of cases.

    Each case is ``(pytree, to_pytree, expected)``: ``pytree`` is broadcast
    to the spec of ``to_pytree`` and the result is compared with
    ``expected``, which is ``None`` for the mismatched-structure cases.
    """
    cases = [
        (1, (), []),
        # Same (flat) structures
        ((1,), (0,), [1]),
        ([1], [0], [1]),
        ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
        ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
        # Mismatched (flat) structures
        # (the list-vs-tuple case used to be listed twice; once is enough)
        ([1], (0,), None),
        ((1,), [0], None),
        ((1, 2, 3), (0, 0), None),
        ({"a": 1, "b": 2}, {"a": 0}, None),
        ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
        ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
        # Same (nested) structures
        ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
        ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
        # Mismatched (nested) structures
        ((1, [2, 3]), (0, (0, 0)), None),
        ((1, [2, 3]), (0, [0, 0, 0]), None),
        # Broadcasting single value
        (1, (0, 0, 0), [1, 1, 1]),
        (1, [0, 0, 0], [1, 1, 1]),
        (1, {"a": 0, "b": 0}, [1, 1]),
        (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
        (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
        # Broadcast multiple things
        ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
        ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
        (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
    ]
    for pytree, to_pytree, expected in cases:
        _, to_spec = pytree_impl.tree_flatten(to_pytree)
        result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
        self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
@parametrize(
    "pytree_impl",
    [subtest(py_pytree, name="py"), subtest(cxx_pytree, name="cxx")],
)
def test_pytree_serialize_bad_input(self, pytree_impl):
    """``treespec_dumps`` must raise ``TypeError`` when handed a plain string."""
    not_a_treespec = "random_blurb"
    with self.assertRaises(TypeError):
        pytree_impl.treespec_dumps(not_a_treespec)
class TestPythonPytree(TestCase):
def test_deprecated_register_pytree_node(self):
    """The private ``_register_pytree_node`` must warn on use.

    The first call is expected to emit a ``FutureWarning`` naming the
    private API; registering the same type again is expected to emit a
    ``UserWarning`` matching "already registered".
    """

    class DummyType:
        def __init__(self, x, y):
            self.x, self.y = x, y

    def flatten_dummy(dummy):
        return [dummy.x, dummy.y], None

    def unflatten_dummy(xs, _):
        return DummyType(*xs)

    with self.assertWarnsRegex(
        FutureWarning, "torch.utils._pytree._register_pytree_node"
    ):
        py_pytree._register_pytree_node(DummyType, flatten_dummy, unflatten_dummy)

    with self.assertWarnsRegex(UserWarning, "already registered"):
        py_pytree._register_pytree_node(DummyType, flatten_dummy, unflatten_dummy)
def test_import_pytree_doesnt_import_optree(self):
    """Importing ``torch.utils._pytree`` must not pull in the C++ pytree or optree.

    The check runs in a fresh interpreter so that modules already imported
    by this test file do not affect ``sys.modules``. Any failure of the
    child process fails the test with the child's combined output.
    """
    # importing torch.utils._pytree shouldn't import optree.
    # only importing torch.utils._cxx_pytree should.
    # NOTE: the script is passed to ``python -c`` verbatim, so the bodies of
    # the ``if`` statements must stay indented inside the string literal.
    script = """
import sys
import torch
import torch.utils._pytree
assert "torch.utils._pytree" in sys.modules
if "torch.utils._cxx_pytree" in sys.modules:
    raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
if "optree" in sys.modules:
    raise RuntimeError("importing torch.utils._pytree should not import optree")
"""
    try:
        subprocess.check_output(
            [sys.executable, "-c", script],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),
        )
    except subprocess.CalledProcessError as e:
        self.fail(
            msg=(
                "Subprocess exception while attempting to run test: "
                + e.output.decode("utf-8")
            )
        )
def test_treespec_equality(self):
    """Exercise ``==`` and ``!=`` on Python ``TreeSpec``/``LeafSpec`` objects."""
    TreeSpec = py_pytree.TreeSpec
    LeafSpec = py_pytree.LeafSpec

    # Independently constructed but structurally identical specs are equal.
    self.assertEqual(LeafSpec(), LeafSpec())
    self.assertEqual(TreeSpec(list, None, []), TreeSpec(list, None, []))
    self.assertEqual(
        TreeSpec(list, None, [LeafSpec()]),
        TreeSpec(list, None, [LeafSpec()]),
    )

    # Specs with different container types differ; both operators are
    # invoked explicitly.
    self.assertFalse(TreeSpec(tuple, None, []) == TreeSpec(list, None, []))
    self.assertTrue(TreeSpec(tuple, None, []) != TreeSpec(list, None, []))
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
def test_treespec_repr(self):
    """Check the multi-line ``repr`` of a nested Python ``TreeSpec``.

    Skipped when ``TEST_WITH_TORCHDYNAMO`` is set; ``test_treespec_repr_dynamo``
    covers that configuration.
    """
    # Check that it looks sane
    pytree = (0, [0, 0, [0]])
    _, spec = py_pytree.tree_flatten(pytree)
    # Continuation lines are indented two extra spaces per nesting level;
    # that leading whitespace is part of the expected value.
    self.assertEqual(
        repr(spec),
        (
            "TreeSpec(tuple, None, [*,\n"
            "  TreeSpec(list, None, [*,\n"
            "    *,\n"
            "    TreeSpec(list, None, [*])])])"
        ),
    )
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
def test_treespec_repr_dynamo(self):
    """Check the ``repr`` of a nested Python ``TreeSpec`` under TorchDynamo.

    Runs only when ``TEST_WITH_TORCHDYNAMO`` is set; ``test_treespec_repr``
    covers the eager configuration.
    """
    # Check that it looks sane
    pytree = (0, [0, 0, [0]])
    _, spec = py_pytree.tree_flatten(pytree)
    # The expected text is flush-left on purpose: its leading whitespace
    # (two extra spaces per nesting level) is part of the compared value.
    self.assertExpectedInline(
        repr(spec),
        """\
TreeSpec(tuple, None, [*,
  TreeSpec(list, None, [*,
    *,
    TreeSpec(list, None, [*])])])""",
    )
@parametrize(
    "spec",
    [
        # py_pytree.tree_structure([])
        py_pytree.TreeSpec(list, None, []),
        # py_pytree.tree_structure(())
        py_pytree.TreeSpec(tuple, None, []),
        # py_pytree.tree_structure({})
        py_pytree.TreeSpec(dict, [], []),
        # py_pytree.tree_structure([0])
        py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
        # py_pytree.tree_structure([0, 1])
        py_pytree.TreeSpec(
            list, None, [py_pytree.LeafSpec() for _ in range(2)]
        ),
        # py_pytree.tree_structure((0, 1, 2))
        py_pytree.TreeSpec(
            tuple, None, [py_pytree.LeafSpec() for _ in range(3)]
        ),
        # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
        py_pytree.TreeSpec(
            dict, ["a", "b", "c"], [py_pytree.LeafSpec() for _ in range(3)]
        ),
        # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
        py_pytree.TreeSpec(
            OrderedDict,
            ["a", "b", "c"],
            [
                py_pytree.TreeSpec(
                    tuple, None, [py_pytree.LeafSpec() for _ in range(2)]
                ),
                py_pytree.LeafSpec(),
                py_pytree.TreeSpec(
                    dict,
                    ["a", "b", "c"],
                    [py_pytree.LeafSpec() for _ in range(3)],
                ),
            ],
        ),
        # py_pytree.tree_structure([(0, 1, [2, 3])])
        py_pytree.TreeSpec(
            list,
            None,
            [
                py_pytree.TreeSpec(
                    tuple,
                    None,
                    [
                        py_pytree.LeafSpec(),
                        py_pytree.LeafSpec(),
                        py_pytree.TreeSpec(
                            list, None, [py_pytree.LeafSpec() for _ in range(2)]
                        ),
                    ],
                ),
            ],
        ),
        # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}))
        py_pytree.TreeSpec(
            defaultdict,
            [list, ["a", "b", "c"]],
            [
                py_pytree.TreeSpec(
                    list, None, [py_pytree.LeafSpec() for _ in range(2)]
                ),
                py_pytree.TreeSpec(
                    list, None, [py_pytree.LeafSpec() for _ in range(2)]
                ),
                py_pytree.TreeSpec(dict, [], []),
            ],
        ),
    ],
)
def test_pytree_serialize(self, spec):
    """Round-trip a hand-built ``TreeSpec`` through dumps/loads.

    The spec is first validated by rebuilding a placeholder tree from it
    and checking that the tree's structure equals the spec.
    """
    # Ensure that the spec is valid
    placeholder_tree = py_pytree.tree_unflatten([0] * spec.num_leaves, spec)
    self.assertEqual(spec, py_pytree.tree_structure(placeholder_tree))

    dumped = py_pytree.treespec_dumps(spec)
    self.assertIsInstance(dumped, str)
    self.assertEqual(spec, py_pytree.treespec_loads(dumped))
def test_pytree_serialize_namedtuple(self):
    """Round-trip specs for registered namedtuple types through dumps/loads.

    Covers a ``collections.namedtuple`` type and a ``typing.NamedTuple``
    class, each registered under its own serialized type name.
    """
    Point1 = namedtuple("Point1", ["x", "y"])

    class Point2(NamedTuple):
        x: int
        y: int

    registrations = [
        (Point1, "test_pytree.test_pytree_serialize_namedtuple.Point1"),
        (Point2, "test_pytree.test_pytree_serialize_namedtuple.Point2"),
    ]
    for point_type, type_name in registrations:
        py_pytree._register_namedtuple(
            point_type,
            serialized_type_name=type_name,
        )
        spec = py_pytree.TreeSpec(
            namedtuple, point_type, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
        )
        dumped = py_pytree.treespec_dumps(spec)
        roundtrip_spec = py_pytree.treespec_loads(dumped)
        self.assertEqual(spec, roundtrip_spec)
def test_pytree_serialize_namedtuple_bad(self):
    """Dumping a spec for a namedtuple type that was not registered must fail."""
    DummyType = namedtuple("DummyType", ["x", "y"])
    leaf_specs = [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
    spec = py_pytree.TreeSpec(namedtuple, DummyType, leaf_specs)

    expected_error = "Please register using `_register_namedtuple`"
    with self.assertRaisesRegex(NotImplementedError, expected_error):
        py_pytree.treespec_dumps(spec)
def test_pytree_custom_type_serialize_bad(self):
    """Dumping a spec for a custom type registered without a serialized name must fail.

    ``DummyType`` is registered with flatten/unflatten functions only, so
    ``treespec_dumps`` is expected to raise ``NotImplementedError``.
    """

    class DummyType:
        def __init__(self, x, y):
            self.x = x
            self.y = y

    py_pytree.register_pytree_node(
        DummyType,
        lambda dummy: ([dummy.x, dummy.y], None),
        lambda xs, _: DummyType(*xs),
    )
    spec = py_pytree.TreeSpec(
        DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
    )
    with self.assertRaisesRegex(
        NotImplementedError, "No registered serialization name"
    ):
        # The call is expected to raise, so its result is not bound.
        py_pytree.treespec_dumps(spec)
def test_pytree_custom_type_serialize(self):
    """Round-trip a spec for a custom type registered with dumpable-context hooks.

    The ``to_dumpable_context`` hook returns ``"moo"``, which must appear
    in the serialized text; ``from_dumpable_context`` returns ``None``,
    matching the ``None`` context of the hand-built spec.
    """

    class DummyType:
        def __init__(self, x, y):
            self.x, self.y = x, y

    py_pytree.register_pytree_node(
        DummyType,
        lambda dummy: ([dummy.x, dummy.y], None),
        lambda xs, _: DummyType(*xs),
        serialized_type_name="test_pytree_custom_type_serialize.DummyType",
        to_dumpable_context=lambda context: "moo",
        from_dumpable_context=lambda dumpable_context: None,
    )

    leaf_specs = [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
    spec = py_pytree.TreeSpec(DummyType, None, leaf_specs)

    # The second positional argument selects serialization protocol 1.
    dumped = py_pytree.treespec_dumps(spec, 1)
    self.assertIn("moo", dumped)
    self.assertEqual(py_pytree.treespec_loads(dumped), spec)
def test_pytree_serialize_register_bad(self):
    """Registering with only ``to_dumpable_context`` must raise ``ValueError``."""

    class DummyType:
        def __init__(self, x, y):
            self.x, self.y = x, y

    expected_error = "Both to_dumpable_context and from_dumpable_context"
    with self.assertRaisesRegex(ValueError, expected_error):
        # ``from_dumpable_context`` is deliberately omitted.
        py_pytree.register_pytree_node(
            DummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: DummyType(*xs),
            serialized_type_name="test_pytree_serialize_register_bad.DummyType",
            to_dumpable_context=lambda context: "moo",
        )
def test_pytree_context_serialize_bad(self):
    """A dumpable context that is not JSON serializable must raise ``TypeError``.

    The ``to_dumpable_context`` hook returns the ``DummyType`` class
    itself, which the expected error message reports as type ``type``.
    """

    class DummyType:
        def __init__(self, x, y):
            self.x, self.y = x, y

    py_pytree.register_pytree_node(
        DummyType,
        lambda dummy: ([dummy.x, dummy.y], None),
        lambda xs, _: DummyType(*xs),
        serialized_type_name="test_pytree_serialize_serialize_bad.DummyType",
        to_dumpable_context=lambda context: DummyType,
        from_dumpable_context=lambda dumpable_context: None,
    )

    leaf_specs = [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
    spec = py_pytree.TreeSpec(DummyType, None, leaf_specs)

    expected_error = "Object of type type is not JSON serializable"
    with self.assertRaisesRegex(TypeError, expected_error):
        py_pytree.treespec_dumps(spec)
def test_pytree_serialize_bad_protocol(self):
    """Unknown protocol numbers must be rejected by both dumps and loads.

    ``treespec_dumps`` is called with protocol ``-1`` directly. For
    ``treespec_loads``, a valid serialization is decoded as JSON and
    re-encoded with its protocol field replaced by ``-1``.
    """
    import json

    Point = namedtuple("Point", ["x", "y"])
    spec = py_pytree.TreeSpec(
        namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
    )
    py_pytree._register_namedtuple(
        Point,
        serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
    )

    with self.assertRaisesRegex(ValueError, "Unknown protocol"):
        py_pytree.treespec_dumps(spec, -1)

    serialized_spec = py_pytree.treespec_dumps(spec)
    # Only the payload is reused; the original protocol number is discarded.
    _, data = json.loads(serialized_spec)
    bad_protocol_serialized_spec = json.dumps((-1, data))

    with self.assertRaisesRegex(ValueError, "Unknown protocol"):
        py_pytree.treespec_loads(bad_protocol_serialized_spec)
def test_saved_serialized(self):
    """Pin the exact serialized text of a nested spec.

    The output of ``treespec_dumps`` must equal a stored string, and
    loading that stored string must reproduce the spec.
    """
    LeafSpec = py_pytree.LeafSpec

    # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})]))
    tuple_spec = py_pytree.TreeSpec(tuple, None, [LeafSpec(), LeafSpec()])
    dict_spec = py_pytree.TreeSpec(
        dict, [4, 5, 6], [LeafSpec(), LeafSpec(), LeafSpec()]
    )
    complicated_spec = py_pytree.TreeSpec(
        OrderedDict, [1, 2, 3], [tuple_spec, LeafSpec(), dict_spec]
    )

    # Ensure that the spec is valid
    placeholder_tree = py_pytree.tree_unflatten(
        [0] * complicated_spec.num_leaves, complicated_spec
    )
    self.assertEqual(complicated_spec, py_pytree.tree_structure(placeholder_tree))

    serialized_spec = py_pytree.treespec_dumps(complicated_spec)
    saved_spec = (
        '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
        '"children_spec": [{"type": "builtins.tuple", "context": "null", '
        '"children_spec": [{"type": null, "context": null, '
        '"children_spec": []}, {"type": null, "context": null, '
        '"children_spec": []}]}, {"type": null, "context": null, '
        '"children_spec": []}, {"type": "builtins.dict", "context": '
        '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
        '"children_spec": []}, {"type": null, "context": null, "children_spec": '
        '[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
    )
    self.assertEqual(serialized_spec, saved_spec)
    self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
def test_tree_map_with_path(self):
    """Use the key path inside the mapped function.

    Every value equals its dict key and the enclosing list index is 0,
    so ``val - key + idx`` is expected to be 0 for every leaf.
    """

    def value_minus_key_plus_index(kp, val):
        # kp[0] addresses the list entry, kp[1] the dict entry.
        return val - kp[1].key + kp[0].idx

    tree = [{i: i for i in range(10)}]
    all_zeros = py_pytree.tree_map_with_path(value_minus_key_plus_index, tree)
    self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
def test_tree_map_with_path_multiple_trees(self):
    """``tree_map_with_path`` over two trees must match a one-tree ``tree_map``.

    Every leaf of the second tree is 2, so adding the two trees leaf-wise
    is compared with adding 2 to each leaf of the first tree.
    """

    @dataclass
    class ACustomPytree:
        x: Any
        y: Any
        z: Any

    # ``x`` and ``y`` are children; ``z`` is carried as the context.
    py_pytree.register_pytree_node(
        ACustomPytree,
        flatten_fn=lambda f: ([f.x, f.y], f.z),
        unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
        flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
    )

    tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
    tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2]

    from_two_trees = py_pytree.tree_map_with_path(
        lambda kp, a, b: a + b, tree1, tree2
    )
    from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
    self.assertEqual(from_two_trees, from_one_tree)
@skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
def test_tree_flatten_with_path_is_leaf(self):
    """``tree_flatten_with_path`` must honour ``is_leaf``.

    With a predicate that treats dicts as leaves, the last flattened entry
    must be the very same dict object that was placed in the tree.
    """
    leaf_dict = {"foo": [(3)]}
    pytree = (["hello", [1, 2], leaf_dict],)
    key_leaves, _ = py_pytree.tree_flatten_with_path(
        pytree, is_leaf=lambda x: isinstance(x, dict)
    )
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(key_leaves[-1][1], leaf_dict)
def test_tree_flatten_with_path_roundtrip(self):
    """Leaves from ``tree_flatten_with_path`` must unflatten to the original tree."""

    class ANamedTuple(NamedTuple):
        x: torch.Tensor
        y: int
        z: str

    @dataclass
    class ACustomPytree:
        x: Any
        y: Any
        z: Any

    # ``x`` and ``y`` are children; ``z`` is carried as the context.
    py_pytree.register_pytree_node(
        ACustomPytree,
        flatten_fn=lambda f: ([f.x, f.y], f.z),
        unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
        flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
    )

    sample_trees = [
        (None,),
        ["hello", [1, 2], {"foo": [(3)]}],
        [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
        [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
    ]
    for pytree in sample_trees:
        key_leaves, spec = py_pytree.tree_flatten_with_path(pytree)
        # Drop the key paths and keep only the leaf values.
        leaves = [leaf for _, leaf in key_leaves]
        rebuilt = py_pytree.tree_unflatten(leaves, spec)
        self.assertEqual(rebuilt, pytree)
def test_tree_leaves_with_path(self):
    """``tree_leaves_with_path`` must equal the leaves of ``tree_flatten_with_path``."""

    class ANamedTuple(NamedTuple):
        x: torch.Tensor
        y: int
        z: str

    @dataclass
    class ACustomPytree:
        x: Any
        y: Any
        z: Any

    # ``x`` and ``y`` are children; ``z`` is carried as the context.
    py_pytree.register_pytree_node(
        ACustomPytree,
        flatten_fn=lambda f: ([f.x, f.y], f.z),
        unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
        flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
    )

    sample_trees = [
        (None,),
        ["hello", [1, 2], {"foo": [(3)]}],
        [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
        [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
    ]
    for pytree in sample_trees:
        from_flatten, _ = py_pytree.tree_flatten_with_path(pytree)
        from_leaves = py_pytree.tree_leaves_with_path(pytree)
        self.assertEqual(from_flatten, from_leaves)
def test_key_str(self):
    """``keystr`` must render each key path as an index/key/attribute chain."""

    class ANamedTuple(NamedTuple):
        x: str
        y: int

    tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
    key_leaves, _ = py_pytree.tree_flatten_with_path(tree)

    rendered = []
    for key_path, leaf in key_leaves:
        rendered.append(f"{py_pytree.keystr(key_path)}: {leaf}")

    expected = [
        "[0][0]: hello",
        "[0][1][0]: 1",
        "[0][1][1]: 2",
        "[0][2]['foo'][0]: 3",
        "[0][2]['bar'][0].x: baz",
        "[0][2]['bar'][0].y: 10",
    ]
    self.assertEqual(rendered, expected)
@skipIfTorchDynamo("AssertionError in dynamo")
def test_flatten_flatten_with_key_consistency(self):
    """Check that flatten and flatten_with_key produces consistent leaves/context."""
    # One example instance per node type; registered types without an
    # entry here are skipped.
    examples_by_type = {
        list: [1, 2, 3],
        tuple: (1, 2, 3),
        dict: {"foo": 1, "bar": 2},
        namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
        OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
        defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
        deque: deque([1, 2, 3]),
        torch.Size: torch.Size([1, 2, 3]),
        immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
        immutable_list: immutable_list([1, 2, 3]),
    }

    for node_type in py_pytree.SUPPORTED_NODES:
        example = examples_by_type.get(node_type)
        if example is None:
            continue

        key_leaves, spec_with_path = py_pytree.tree_flatten_with_path(example)
        leaves, spec_plain = py_pytree.tree_flatten(example)

        self.assertEqual(leaves, [entry[1] for entry in key_leaves])
        self.assertEqual(spec_with_path, spec_plain)
def test_key_access(self):
    """``key_get`` with a leaf's key path must return that leaf."""

    class ANamedTuple(NamedTuple):
        x: str
        y: int

    tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
    key_leaves, _ = py_pytree.tree_flatten_with_path(tree)
    for key_path, leaf in key_leaves:
        looked_up = py_pytree.key_get(tree, key_path)
        self.assertEqual(looked_up, leaf)
class TestCxxPytree(TestCase):
def setUp(self):
    """Skip every test in this class under fbcode, otherwise run base setup.

    Raises:
        unittest.SkipTest: if ``IS_FBCODE`` is set.
    """
    if IS_FBCODE:
        raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
    # Chain to the base class so its per-test setup still runs.
    super().setUp()
def test_treespec_equality(self):
    """Two independently constructed C++ ``LeafSpec`` objects compare equal."""
    first = cxx_pytree.LeafSpec()
    second = cxx_pytree.LeafSpec()
    self.assertEqual(first, second)
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
def test_treespec_repr(self):
    """Check the ``repr`` of a nested C++ tree spec.

    Skipped when ``TEST_WITH_TORCHDYNAMO`` is set; ``test_treespec_repr_dynamo``
    covers that configuration.
    """
    # Check that it looks sane
    _, spec = cxx_pytree.tree_flatten((0, [0, 0, [0]]))
    expected_repr = "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)"
    self.assertEqual(repr(spec), expected_repr)
@unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
def test_treespec_repr_dynamo(self):
    """Check the ``repr`` of a nested C++ tree spec under TorchDynamo.

    Runs only when ``TEST_WITH_TORCHDYNAMO`` is set; ``test_treespec_repr``
    covers the eager configuration.
    """
    # Check that it looks sane
    _, spec = cxx_pytree.tree_flatten((0, [0, 0, [0]]))
    actual_repr = repr(spec)
    self.assertExpectedInline(
        actual_repr,
        "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
    )
@parametrize(
    "spec",
    [
        cxx_pytree.tree_structure([]),
        cxx_pytree.tree_structure(()),
        cxx_pytree.tree_structure({}),
        cxx_pytree.tree_structure([0]),
        cxx_pytree.tree_structure([0, 1]),
        cxx_pytree.tree_structure((0, 1, 2)),
        cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
        cxx_pytree.tree_structure(
            OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
        ),
        cxx_pytree.tree_structure([(0, 1, [2, 3])]),
        cxx_pytree.tree_structure(
            defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
        ),
    ],
)
def test_pytree_serialize(self, spec):
    """Round-trip a C++ tree spec through dumps/loads.

    The spec is first checked by rebuilding a placeholder tree from it and
    comparing that tree's structure with the spec.
    """
    placeholder_tree = cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
    self.assertEqual(spec, cxx_pytree.tree_structure(placeholder_tree))

    dumped = cxx_pytree.treespec_dumps(spec)
    self.assertIsInstance(dumped, str)
    self.assertEqual(spec, cxx_pytree.treespec_loads(dumped))
def test_pytree_serialize_namedtuple(self):
    """Round-trip C++ specs for a module-level and a local namedtuple type.

    Only the ``_fields`` of the round-tripped spec's ``type`` are compared
    with those of the original spec.
    """

    def check_fields_roundtrip(point_type, type_name):
        py_pytree._register_namedtuple(
            point_type,
            serialized_type_name=type_name,
        )
        spec = cxx_pytree.tree_structure(point_type(0, 1))
        dumped = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(dumped)
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

    check_fields_roundtrip(
        GlobalPoint,
        "test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
    )

    LocalPoint = namedtuple("LocalPoint", ["x", "y"])
    check_fields_roundtrip(
        LocalPoint,
        "test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
    )
def test_pytree_custom_type_serialize(self):
    """Round-trip C++ specs for a module-level and a local custom node type."""

    class LocalDummyType:
        def __init__(self, x, y):
            self.x, self.y = x, y

    def assert_spec_roundtrips(spec):
        dumped = cxx_pytree.treespec_dumps(spec)
        self.assertEqual(cxx_pytree.treespec_loads(dumped), spec)

    # Module-level type.
    cxx_pytree.register_pytree_node(
        GlobalDummyType,
        lambda dummy: ([dummy.x, dummy.y], None),
        lambda xs, _: GlobalDummyType(*xs),
        serialized_type_name="GlobalDummyType",
    )
    assert_spec_roundtrips(cxx_pytree.tree_structure(GlobalDummyType(0, 1)))

    # Type defined inside this test.
    cxx_pytree.register_pytree_node(
        LocalDummyType,
        lambda dummy: ([dummy.x, dummy.y], None),
        lambda xs, _: LocalDummyType(*xs),
        serialized_type_name="LocalDummyType",
    )
    assert_spec_roundtrips(cxx_pytree.tree_structure(LocalDummyType(0, 1)))
# Instantiate the parametrized tests declared on each test class.
instantiate_parametrized_tests(TestGenericPytree)
instantiate_parametrized_tests(TestPythonPytree)
instantiate_parametrized_tests(TestCxxPytree)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()