| """ |
| Contains utility functions for working with nested python data structures. |
| |
| A *pytree* is a Python nested data structure. It is a tree in the sense that |
| nodes are Python collections (e.g., list, tuple, dict) and the leaves are |
| Python values. Furthermore, a pytree should not contain reference cycles. |
| |
| pytrees are useful for working with nested collections of Tensors. For example, |
| one can use `tree_map` to map a function over all Tensors inside some nested |
| collection of Tensors and `tree_leaves` to get a flat list of all Tensors |
| inside some nested collection. pytrees are helpful for implementing nested |
| collection support for PyTorch APIs. |
| """ |
| |
| import functools |
| from typing import ( |
| Any, |
| Callable, |
| Iterable, |
| List, |
| Optional, |
| overload, |
| Tuple, |
| Type, |
| TypeVar, |
| Union, |
| ) |
| |
| import optree |
| from optree import PyTreeSpec # direct import for type annotations |
| |
| |
# Names exported by a star-import of this module.
# NOTE(review): ``broadcast_prefix`` and ``map_only`` are defined at module level
# without a leading underscore but are not listed here -- confirm that leaving
# them out of the exported names is intentional.
__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "TreeSpec",
    "LeafSpec",
    "register_pytree_node",
    "tree_flatten",
    "tree_unflatten",
    "tree_leaves",
    "tree_structure",
    "tree_map",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
]
| |
| |
# Generic type variables used by the typed aliases and overloads in this module.
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")


# Auxiliary data produced by a flatten function next to the children.
Context = Optional[Any]
# Any Python object may be handled as a pytree (a single value is a one-leaf tree).
PyTree = Any
# Tree specs in this module are optree ``PyTreeSpec`` objects.
TreeSpec = PyTreeSpec
# Flatten function: node -> (children, context).
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
# Unflatten function in this module's argument order: (children, context) -> node.
UnflattenFunc = Callable[[Iterable, Context], PyTree]
# Unflatten function in the argument order handed to optree: (context, children) -> node.
OpTreeUnflattenFunc = Callable[[Context, Iterable], PyTree]
| |
| |
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
    """Return a wrapper that calls ``func`` with its positional arguments in reverse order.

    Keyword arguments are passed through untouched. Within this module the helper
    converts a ``(children, context)`` unflatten function into the
    ``(context, children)`` order described by ``OpTreeUnflattenFunc``.

    Args:
        func (callable): The function to adapt.

    Returns:
        A callable carrying the metadata of ``func`` (via :func:`functools.wraps`) that
        forwards ``*args`` reversed.
    """

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        flipped = args[::-1]
        return func(*flipped, **kwargs)

    return wrapped
| |
| |
def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    namespace: str = "torch",
) -> None:
    """Extend the set of types that are considered internal nodes in pytrees.

    The type is registered twice: once through ``_register_pytree_node`` of the sibling
    ``_pytree`` module (which receives ``serialized_type_name``) and once through
    :func:`optree.register_pytree_node` (which receives ``namespace``).

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique
    prefix to the namespace to avoid conflicts with other libraries. Namespaces can also be used
    to register the same class several times for different use cases.

    .. note::
        ``namespace`` isolates the flattening and unflattening behavior of a custom node type so
        that independent libraries registering the same type do not interfere with each other.
        When it is not given, the ``"torch"`` namespace is used.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) the children to be flattened recursively, and
            (2) auxiliary data (the context) to be stored in the treespec and to be passed to
            the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments, in this order: the unflattened
            children, and the auxiliary data that was returned by ``flatten_fn`` and stored in
            the treespec. The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec.
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of
            the type registry. This is used to isolate the registry from other modules that
            might register a different custom behavior for the same type.
            (default: :const:`"torch"`)

    Example::

        >>> # xdoctest: +SKIP
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None),
        ...     lambda children, _: set(children),
        ...     namespace='set',
        ... )

        >>> # xdoctest: +SKIP
        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_fn=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_fn=lambda children, metadata: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes  # xdoctest: +SKIP
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # xdoctest: +SKIP
        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')  # xdoctest: +SKIP
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, ...}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # xdoctest: +SKIP
        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape
        ...
        >>> def flatparam2tensor(children, metadata):
        ...     return children[0].reshape(metadata)
        ...
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_fn=tensor2flatparam,
        ...     unflatten_fn=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )

        >>> # xdoctest: +SKIP
        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')  # xdoctest: +SKIP
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    """
    # Imported lazily inside the function; the local name shadows the module-level
    # alias of the same name defined right after this function.
    from ._pytree import _register_pytree_node

    # Register with the ``_pytree`` module; the functions are passed through unchanged.
    _register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
    )

    # optree is given the unflatten function with its two arguments swapped, i.e. in
    # ``(context, children)`` order (see ``OpTreeUnflattenFunc``).
    optree.register_pytree_node(
        cls,
        flatten_fn,
        _reverse_args(unflatten_fn),
        namespace=namespace,
    )


# Underscore-prefixed alias of the public registration function.
_register_pytree_node = register_pytree_node
| |
| |
def tree_flatten(
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> Tuple[List[Any], TreeSpec]:
    """Split a pytree into a list of leaves and a treespec.

    This is the counterpart of :func:`tree_unflatten`; the work is delegated to
    :func:`optree.tree_flatten`.

    Leaves are produced in a deterministic order that follows a left-to-right, depth-first
    traversal of the tree.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)
    ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
    >>> tree_flatten(tree, none_is_leaf=False)
    ([1, 2, 3, 4, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}))
    >>> tree_flatten(1)
    ([1], PyTreeSpec(*, NoneIsLeaf))
    >>> tree_flatten(None)
    ([None], PyTreeSpec(*, NoneIsLeaf))
    >>> tree_flatten(None, none_is_leaf=False)
    ([], PyTreeSpec(None))

    Unordered mappings (:class:`dict` and :class:`collections.defaultdict`) are traversed by
    their **sorted** keys. Use :class:`collections.OrderedDict` when the insertion order of the
    keys has to be kept.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten(tree)
    ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf))
    >>> tree_flatten(tree, none_is_leaf=False)
    ([2, 3, 4, 1, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', None), ('d', *)])))

    Args:
        tree (pytree): The pytree to flatten.
        none_is_leaf (bool, optional): If :data:`True`, :data:`None` counts as a leaf. If
            :data:`False`, :data:`None` is a non-leaf node with arity 0 and is therefore
            recorded in the treespec instead of the leaves list. (default: :data:`True`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`"torch"`)

    Returns:
        A pair ``(leaves, treespec)``: the list of leaf values and the treespec describing the
        structure of the pytree.
    """
    options = {"none_is_leaf": none_is_leaf, "namespace": namespace}
    flattened = optree.tree_flatten(tree, **options)
    return flattened  # type: ignore[return-value]
| |
| |
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Rebuild a pytree from leaves and a treespec.

    This undoes :func:`tree_flatten`. Note that :func:`optree.tree_unflatten` is called with
    the treespec first and the leaves second, the opposite of this function's own order.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> leaves, treespec = tree_flatten(tree)
    >>> tree == tree_unflatten(leaves, treespec)
    True

    Args:
        leaves (iterable): The leaves to place into the structure. Their number must match the
            number of leaves of the treespec.
        treespec (TreeSpec): The treespec describing the structure to rebuild.

    Returns:
        The reconstructed pytree, with ``leaves`` placed in the structure described by
        ``treespec``.

    Raises:
        TypeError: If ``treespec`` is not a ``TreeSpec`` instance.
    """
    if isinstance(treespec, TreeSpec):
        return optree.tree_unflatten(treespec, leaves)  # type: ignore[arg-type]

    raise TypeError(
        "tree_unflatten(values, spec): Expected `spec` to be instance of "
        f"TreeSpec but got item of type {type(treespec)}."
    )
| |
| |
def tree_leaves(
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> List[Any]:
    """Return only the leaves of a pytree, discarding its structure.

    Delegates to :func:`optree.tree_leaves`. See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_leaves(tree)
    [1, 2, 3, 4, None, 5]
    >>> tree_leaves(tree, none_is_leaf=False)
    [1, 2, 3, 4, 5]
    >>> tree_leaves(1)
    [1]
    >>> tree_leaves(None)
    [None]
    >>> tree_leaves(None, none_is_leaf=False)
    []

    Args:
        tree (pytree): The pytree whose leaves are collected.
        none_is_leaf (bool, optional): If :data:`True`, :data:`None` counts as a leaf. If
            :data:`False`, :data:`None` is a non-leaf node with arity 0 and is therefore
            recorded in the treespec instead of the leaves list. (default: :data:`True`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`"torch"`)

    Returns:
        A list of leaf values.
    """
    leaves = optree.tree_leaves(
        tree,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return leaves
| |
| |
def tree_structure(
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> TreeSpec:
    """Return the treespec of a pytree, discarding its leaves.

    Delegates to :func:`optree.tree_structure`. See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_structure(tree)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    >>> tree_structure(tree, none_is_leaf=False)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    >>> tree_structure(1)
    PyTreeSpec(*, NoneIsLeaf)
    >>> tree_structure(None)
    PyTreeSpec(*, NoneIsLeaf)
    >>> tree_structure(None, none_is_leaf=False)
    PyTreeSpec(None)

    Args:
        tree (pytree): The pytree whose structure is computed.
        none_is_leaf (bool, optional): If :data:`True`, :data:`None` counts as a leaf. If
            :data:`False`, :data:`None` is a non-leaf node with arity 0 and is therefore
            recorded in the treespec instead of the leaves list. (default: :data:`True`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`"torch"`)

    Returns:
        A treespec object representing the structure of the pytree.
    """
    treespec = optree.tree_structure(tree, none_is_leaf=none_is_leaf, namespace=namespace)
    return treespec  # type: ignore[return-value]
| |
| |
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    """Apply a function leaf-wise over one or more pytrees and build a new pytree.

    Delegates to :func:`optree.tree_map`. See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}
    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=False)
    {'x': 8, 'y': (43, 65), 'z': None}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=False)
    {'x': False, 'y': (False, False), 'z': None}

    When several trees are passed, the structure of the result comes from the first one; the
    remaining trees only need to have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function taking ``1 + len(rests)`` arguments, applied at the
            corresponding leaves of the pytrees.
        tree (pytree): The pytree to map over; each of its leaves supplies the first positional
            argument of ``func``.
        rests (tuple of pytrees): Further pytrees, each with the same structure as ``tree`` or
            with ``tree`` as a prefix.
        none_is_leaf (bool, optional): If :data:`True`, :data:`None` counts as a leaf. If
            :data:`False`, :data:`None` is a non-leaf node with arity 0 and is therefore
            recorded in the treespec instead of the leaves list. (default: :data:`True`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`"torch"`)

    Returns:
        A new pytree with the structure of ``tree`` whose leaves are ``func(x, *xs)``, where
        ``x`` is the leaf of ``tree`` and ``xs`` are the values at the corresponding nodes of
        ``rests``.
    """
    trees = (tree, *rests)
    return optree.tree_map(
        func,
        *trees,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
| |
| |
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    """In-place variant of :func:`tree_map`: call ``func`` on every leaf and return ``tree``.

    Delegates to :func:`optree.tree_map_`. See also :func:`tree_map`.

    Args:
        func (callable): A function taking ``1 + len(rests)`` arguments, applied at the
            corresponding leaves of the pytrees.
        tree (pytree): The pytree to map over; each of its leaves supplies the first positional
            argument of ``func``.
        rests (tuple of pytrees): Further pytrees, each with the same structure as ``tree`` or
            with ``tree`` as a prefix.
        none_is_leaf (bool, optional): If :data:`True`, :data:`None` counts as a leaf. If
            :data:`False`, :data:`None` is a non-leaf node with arity 0 and is therefore
            recorded in the treespec instead of the leaves list. (default: :data:`True`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`"torch"`)

    Returns:
        The original ``tree``. Each leaf reflects the side effect (not the return value) of
        ``func(x, *xs)``, where ``x`` is the leaf of ``tree`` and ``xs`` are the values at the
        corresponding nodes of ``rests``.
    """
    options = {"none_is_leaf": none_is_leaf, "namespace": namespace}
    return optree.tree_map_(func, tree, *rests, **options)
| |
| |
# Type selectors accepted by the ``*_only`` helpers: a 2-tuple of types, a 3-tuple of
# types, or the untyped fallback (a single type or a tuple of types).
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

# Single-argument callables whose parameter type mirrors the selector above.
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# A decorator: takes a function of type ``T`` and returns a callable that accepts any value.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
| |
| |
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
def map_only(__type_or_types: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
def map_only(__type_or_types: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This specialization is needed for the implementations below that call
# ``map_only`` with an argument typed as ``TypeAny``.
@overload
def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(__type_or_types: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    """Build a decorator that applies a function only to values of the given type(s).

    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged. Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'.

    The decorated function may be called with extra positional arguments after the
    first one. The type test looks only at the first argument:

    * when it matches, the result is ``func(x, *xs)``;
    * otherwise the first argument is returned unchanged and the extras are ignored.

    Args:
        __type_or_types: A type or tuple of types, used as the second argument of
            :func:`isinstance`.

    Returns:
        A decorator that wraps a function as described above.
    """

    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapped(x: Any, *xs: Any) -> Any:
            # Extra positionals are forwarded because the ``tree_map_only*`` helpers
            # pass this wrapper to a multi-tree map together with ``rests``; a strict
            # one-argument wrapper would fail with TypeError there.
            if isinstance(x, __type_or_types):
                return func(x, *xs)
            return x

        return wrapped

    return wrapper
| |
| |
@overload
def tree_map_only(
    __type_or_types: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    ...


def tree_map_only(
    __type_or_types: TypeAny,
    func: FnAny[Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    """Like :func:`tree_map`, but apply ``func`` only to leaves of the given type(s).

    ``func`` is first wrapped with :func:`map_only`, and the wrapped function is then passed
    to :func:`tree_map` together with ``tree``, ``rests`` and the keyword options.

    NOTE(review): :func:`tree_map` documents that it calls its function with
    ``1 + len(rests)`` arguments, so the wrapper returned by :func:`map_only` has to accept
    that many arguments when ``rests`` is non-empty -- confirm before relying on ``rests``.

    Args:
        __type_or_types: A type or tuple of types selecting the leaves to transform.
        func (callable): The function applied to the selected leaves.
        tree (pytree): The pytree to map over.
        rests (tuple of pytrees): Further pytrees forwarded to :func:`tree_map`.
        none_is_leaf (bool, optional): Forwarded to :func:`tree_map`. (default: :data:`True`)
        namespace (str, optional): Forwarded to :func:`tree_map`. (default: :const:`"torch"`)

    Returns:
        The pytree returned by :func:`tree_map`.
    """
    selective_func = map_only(__type_or_types)(func)
    options = {"none_is_leaf": none_is_leaf, "namespace": namespace}
    return tree_map(selective_func, tree, *rests, **options)
| |
| |
@overload
def tree_map_only_(
    __type_or_types: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    ...


def tree_map_only_(
    __type_or_types: TypeAny,
    func: FnAny[Any],
    tree: PyTree,
    *rests: PyTree,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> PyTree:
    """Like :func:`tree_map_`, but call ``func`` only on leaves of the given type(s).

    ``func`` is first wrapped with :func:`map_only`, and the wrapped function is then passed
    to :func:`tree_map_` together with ``tree``, ``rests`` and the keyword options.

    NOTE(review): :func:`tree_map_` documents that it calls its function with
    ``1 + len(rests)`` arguments, so the wrapper returned by :func:`map_only` has to accept
    that many arguments when ``rests`` is non-empty -- confirm before relying on ``rests``.

    Args:
        __type_or_types: A type or tuple of types selecting the leaves to visit.
        func (callable): The function called on the selected leaves.
        tree (pytree): The pytree to map over.
        rests (tuple of pytrees): Further pytrees forwarded to :func:`tree_map_`.
        none_is_leaf (bool, optional): Forwarded to :func:`tree_map_`. (default: :data:`True`)
        namespace (str, optional): Forwarded to :func:`tree_map_`. (default: :const:`"torch"`)

    Returns:
        The value returned by :func:`tree_map_`.
    """
    selective_func = map_only(__type_or_types)(func)
    options = {"none_is_leaf": none_is_leaf, "namespace": namespace}
    return tree_map_(selective_func, tree, *rests, **options)
| |
| |
def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    """Return whether ``pred`` is truthy for every leaf of ``tree``.

    The leaves are obtained with :func:`tree_leaves`. Evaluation stops at the first leaf for
    which ``pred`` is falsy, and a tree without leaves yields :data:`True`.

    Args:
        pred (callable): The predicate evaluated on each leaf.
        tree (pytree): The pytree whose leaves are tested.
        none_is_leaf (bool, optional): Forwarded to :func:`tree_leaves`.
            (default: :data:`True`)
        namespace (str, optional): Forwarded to :func:`tree_leaves`.
            (default: :const:`"torch"`)

    Returns:
        :data:`True` if no leaf fails the predicate, :data:`False` otherwise.
    """
    leaves = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
    for leaf in leaves:
        if not pred(leaf):
            return False
    return True
| |
| |
def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    """Return whether ``pred`` is truthy for at least one leaf of ``tree``.

    The leaves are obtained with :func:`tree_leaves`. Evaluation stops at the first leaf for
    which ``pred`` is truthy, and a tree without leaves yields :data:`False`.

    Args:
        pred (callable): The predicate evaluated on each leaf.
        tree (pytree): The pytree whose leaves are tested.
        none_is_leaf (bool, optional): Forwarded to :func:`tree_leaves`.
            (default: :data:`True`)
        namespace (str, optional): Forwarded to :func:`tree_leaves`.
            (default: :const:`"torch"`)

    Returns:
        :data:`True` if some leaf satisfies the predicate, :data:`False` otherwise.
    """
    leaves = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
    for leaf in leaves:
        if pred(leaf):
            return True
    return False
| |
| |
@overload
def tree_all_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    ...


def tree_all_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    """Return whether ``pred`` is truthy for every leaf of the given type(s).

    Leaves that are not instances of ``__type_or_types`` are skipped without calling
    ``pred``. If no leaf matches, the result is :data:`True`.

    Args:
        __type_or_types: A type or tuple of types selecting the leaves to test.
        pred (callable): The predicate evaluated on each selected leaf.
        tree (pytree): The pytree whose leaves are tested.
        none_is_leaf (bool, optional): Forwarded to :func:`tree_leaves`.
            (default: :data:`True`)
        namespace (str, optional): Forwarded to :func:`tree_leaves`.
            (default: :const:`"torch"`)

    Returns:
        :data:`True` if no selected leaf fails the predicate, :data:`False` otherwise.
    """
    leaves = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
    for leaf in leaves:
        if not isinstance(leaf, __type_or_types):
            continue
        if not pred(leaf):
            return False
    return True
| |
| |
@overload
def tree_any_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    ...


def tree_any_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> bool:
    """Return whether ``pred`` is truthy for at least one leaf of the given type(s).

    Leaves that are not instances of ``__type_or_types`` are skipped without calling
    ``pred``. If no leaf matches, the result is :data:`False`.

    Args:
        __type_or_types: A type or tuple of types selecting the leaves to test.
        pred (callable): The predicate evaluated on each selected leaf.
        tree (pytree): The pytree whose leaves are tested.
        none_is_leaf (bool, optional): Forwarded to :func:`tree_leaves`.
            (default: :data:`True`)
        namespace (str, optional): Forwarded to :func:`tree_leaves`.
            (default: :const:`"torch"`)

    Returns:
        :data:`True` if some selected leaf satisfies the predicate, :data:`False` otherwise.
    """
    leaves = tree_leaves(tree, none_is_leaf=none_is_leaf, namespace=namespace)
    for leaf in leaves:
        if isinstance(leaf, __type_or_types) and pred(leaf):
            return True
    return False
| |
| |
def broadcast_prefix(
    prefix_tree: PyTree,
    full_tree: PyTree,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> List[Any]:
    """Replicate the leaves of ``prefix_tree`` so that they line up with the leaves of ``full_tree``.

    ``prefix_tree`` is a prefix of ``full_tree`` when ``full_tree`` can be obtained by
    replacing leaves of ``prefix_tree`` with suitable **subtrees**.

    The returned list has as many entries as ``full_tree`` has leaves. Every leaf of
    ``prefix_tree`` is repeated once per leaf of the subtree that takes its place in
    ``full_tree``. The work is delegated to :func:`optree.broadcast_prefix`.

    >>> broadcast_prefix(1, [1, 2, 3])
    [1, 1, 1]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3])
    [1, 2, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
    >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
    [1, 2, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
    [1, 2, 3, 3, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}], none_is_leaf=False)
    [1, 2, 3, 3, 3]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of ``full_tree``.
        full_tree (pytree): A pytree that extends the structure of ``prefix_tree``.
        none_is_leaf (bool, optional): If :data:`True`, :data:`None` counts as a leaf. If
            :data:`False`, :data:`None` is a non-leaf node with arity 0 and is therefore
            recorded in the treespec instead of the leaves list. (default: :data:`True`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`"torch"`)

    Returns:
        A list of leaves of ``prefix_tree`` broadcast to match the number of leaves in
        ``full_tree``.
    """
    options = {"none_is_leaf": none_is_leaf, "namespace": namespace}
    broadcasted = optree.broadcast_prefix(prefix_tree, full_tree, **options)
    return broadcasted
| |
| |
# Broadcast a pytree to the given TreeSpec and return the flattened values,
# or None when the broadcast is not possible.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# the result is [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and
# _broadcast_to_and_flatten is used to check this.
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    *,
    none_is_leaf: bool = True,
    namespace: str = "torch",
) -> Optional[List[Any]]:
    assert isinstance(treespec, TreeSpec)

    # Materialize a tree with the target structure (every leaf is 0) so that
    # ``tree`` can be broadcast against it as a prefix.
    placeholder_leaves = [0] * treespec.num_leaves
    full_tree = tree_unflatten(placeholder_leaves, treespec)

    try:
        broadcasted = broadcast_prefix(
            tree,
            full_tree,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    except ValueError:
        # A ValueError from broadcast_prefix is reported to the caller as None.
        return None
    return broadcasted
| |
| |
def treespec_dumps(treespec: TreeSpec) -> str:
    """Serialize a treespec to a JSON string.

    The treespec is first turned into a placeholder tree (every leaf is ``0``). The
    ``tree_structure`` function of the sibling ``_pytree`` module computes that tree's
    spec, which ``_pytree.treespec_dumps`` then serializes.

    Args:
        treespec (TreeSpec): The treespec to serialize.

    Returns:
        The string produced by ``_pytree.treespec_dumps``.

    Raises:
        TypeError: If ``treespec`` is not a ``TreeSpec`` instance.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            "treespec_dumps(spec): Expected `spec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}."
        )

    from ._pytree import (
        tree_structure as _py_tree_structure,
        treespec_dumps as _py_treespec_dumps,
    )

    placeholder_leaves = [0] * treespec.num_leaves
    dummy_tree = tree_unflatten(placeholder_leaves, treespec)
    python_treespec = _py_tree_structure(dummy_tree)
    return _py_treespec_dumps(python_treespec)
| |
| |
def treespec_loads(serialized: str) -> TreeSpec:
    """Deserialize a treespec from a JSON string.

    The string is decoded with ``treespec_loads`` of the sibling ``_pytree`` module. The
    decoded spec is unflattened there into a placeholder tree (every leaf is ``0``), and
    the structure of that tree is then computed with this module's :func:`tree_structure`
    using its default options.

    Args:
        serialized (str): The serialized treespec.

    Returns:
        The treespec of the reconstructed placeholder tree.
    """
    from ._pytree import (
        tree_unflatten as _py_tree_unflatten,
        treespec_loads as _py_treespec_loads,
    )

    python_treespec = _py_treespec_loads(serialized)
    placeholder_leaves = [0] * python_treespec.num_leaves
    dummy_tree = _py_tree_unflatten(placeholder_leaves, python_treespec)
    return tree_structure(dummy_tree)
| |
| |
class _DummyLeaf:
    """Placeholder leaf object whose ``repr`` is a single ``*``."""

    def __repr__(self) -> str:
        return "*"


def treespec_pprint(treespec: TreeSpec) -> str:
    """Render a treespec as the ``repr`` of a tree whose leaves all print as ``*``.

    Args:
        treespec (TreeSpec): The treespec to render.

    Returns:
        ``repr`` of the tree obtained by unflattening one ``_DummyLeaf`` per leaf of
        ``treespec``.
    """
    stars = [_DummyLeaf() for _ in range(treespec.num_leaves)]
    skeleton = tree_unflatten(stars, treespec)
    return repr(skeleton)
| |
| |
class LeafSpecMeta(type(TreeSpec)):  # type: ignore[misc]
    """Metaclass that makes ``isinstance(x, LeafSpec)`` true for leaf treespecs."""

    def __instancecheck__(self, instance: object) -> bool:
        # Anything that is not a TreeSpec can never be a leaf spec.
        if not isinstance(instance, TreeSpec):
            return False
        return instance.is_leaf()
| |
| |
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
    """Constructor-style entry point for a leaf treespec.

    Calling ``LeafSpec(...)`` returns the object produced by
    :func:`optree.treespec_leaf`; instance checks are handled by ``LeafSpecMeta``.
    """

    def __new__(cls, none_is_leaf: bool = True) -> "LeafSpec":
        leaf_spec = optree.treespec_leaf(none_is_leaf=none_is_leaf)
        return leaf_spec  # type: ignore[return-value]