""" Python polyfills for torch.utils.pytree """ from __future__ import annotations from collections import deque from dataclasses import dataclass, field from typing import Any, Callable, Literal, TYPE_CHECKING from typing_extensions import TypeIs import torch.utils._pytree as python_pytree from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES from ..decorators import substitute_in_graph if TYPE_CHECKING: import builtins from collections.abc import Iterable from typing_extensions import Self __all__: list[str] = [] if python_pytree._cxx_pytree_dynamo_traceable: import optree import optree._C import torch.utils._cxx_pytree as cxx_pytree if TYPE_CHECKING: from torch.utils._cxx_pytree import PyTree @substitute_in_graph( optree._C.is_dict_insertion_ordered, can_constant_fold_through=True, ) def _(*args: Any, **kwargs: Any) -> bool: # In namespace 'torch', the dictionary is always traversed in insertion order. # This function returns True. raise ValueError( "Should not be called directly " "because the original function will be called in the constant fold path." ) __name = "" for __name in ( "is_namedtuple", "is_namedtuple_class", "is_namedtuple_instance", "is_structseq", "is_structseq_class", "is_structseq_instance", "namedtuple_fields", "structseq_fields", ): __func = getattr(optree, __name) substitute_in_graph(__func, can_constant_fold_through=True)( __func.__python_implementation__ ) del __func del __name @substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True) def tree_is_leaf( tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: if tree is None or (is_leaf is not None and is_leaf(tree)): return True if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined] return True return False @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False) def tree_iter( tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> Iterable[Any]: stack = [tree] while stack: node = stack.pop() if tree_is_leaf(node, is_leaf=is_leaf): yield node continue children, *_ = optree.tree_flatten_one_level( node, is_leaf=is_leaf, none_is_leaf=True, namespace="torch", ) stack.extend(reversed(children)) __all__ += ["tree_iter"] @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True) def tree_leaves( tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[Any]: return list(tree_iter(tree, is_leaf=is_leaf)) __all__ += ["tree_leaves"] class _Asterisk(str): __slots__ = () def __new__(cls) -> Self: return super().__new__(cls, "*") def __repr__(self) -> str: return "*" # no quotes _asterisk = _Asterisk() del _Asterisk @dataclass(frozen=True) class PyTreeSpec: """Analog for :class:`optree.PyTreeSpec` in Python.""" _children: tuple[PyTreeSpec, ...] _type: builtins.type | None _metadata: Any _entries: tuple[Any, ...] _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None num_nodes: int = field(init=False) num_leaves: int = field(init=False) num_children: int = field(init=False) none_is_leaf: Literal[True] = field(init=False) namespace: Literal["torch"] = field(init=False) def __post_init__(self) -> None: if self._type is None: assert len(self._children) == 0 assert self._metadata is None assert self._entries == () assert self._unflatten_func is None num_nodes = 1 num_leaves = 1 num_children = 0 else: assert callable(self._unflatten_func) num_nodes = sum((spec.num_nodes for spec in self._children), start=1) num_leaves = sum(spec.num_leaves for spec in self._children) num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_leaves", num_leaves) object.__setattr__(self, "num_children", num_children) object.__setattr__(self, "none_is_leaf", True) object.__setattr__(self, "namespace", "torch") def __repr__(self) -> str: def helper(treespec: PyTreeSpec) -> str: if treespec.is_leaf(): assert treespec.type is None return _asterisk assert treespec.type is not None assert callable(treespec._unflatten_func) children_representations = [ helper(subspec) for subspec in treespec._children ] if ( treespec.type in BUILTIN_TYPES or optree.is_namedtuple_class(treespec.type) or optree.is_structseq_class(treespec.type) ): return treespec._unflatten_func( treespec._metadata, children_representations, ) return ( f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], " f"[{', '.join(children_representations)}])" ) return ( f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})" ) def __len__(self) -> int: return self.num_leaves @property def type(self) -> builtins.type | None: return self._type def is_leaf(self) -> bool: return self.num_nodes == 1 and self.num_leaves == 1 def children(self) -> list[PyTreeSpec]: return list(self._children) def child(self, index: int) -> PyTreeSpec: return self._children[index] def entries(self) -> list[Any]: return list(self._entries) def entry(self, index: int) -> Any: return self._entries[index] def flatten_up_to(self, tree: PyTree) -> list[PyTree]: def helper( treespec: PyTreeSpec, node: PyTree, subtrees: list[PyTree], ) -> None: if treespec.is_leaf(): subtrees.append(node) return node_type = type(node) if treespec.type not in BUILTIN_TYPES: # Always require custom node types to match exactly if node_type != treespec.type: raise ValueError( f"Type mismatch; " f"expected {treespec.type!r}, but got {node_type!r}.", ) children, metadata, *_ = optree.tree_flatten_one_level( node, none_is_leaf=True, namespace="torch", ) if len(children) != treespec.num_children: raise ValueError( f"Node arity mismatch; " f"expected {treespec.num_children}, but got {len(children)}.", ) if metadata != treespec._metadata: raise ValueError( f"Node context mismatch for custom node type {treespec.type!r}.", ) else: # For builtin dictionary types, we allow some flexibility # Otherwise, we require exact matches both_standard_dict = ( treespec.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES ) if not both_standard_dict and node_type != treespec.type: raise ValueError( f"Node type mismatch; " f"expected {treespec.type!r}, but got {node_type!r}.", ) if len(node) != treespec.num_children: raise ValueError( f"Node arity mismatch; " f"expected {treespec.num_children}, but got {len(node)}.", ) if both_standard_dict: # dictionary types are compatible with each other expected_keys = treespec.entries() got_key_set = set(node) expected_key_set = set(expected_keys) if got_key_set != expected_key_set: missing_keys = expected_key_set.difference(got_key_set) extra_keys = got_key_set.difference(expected_key_set) message = "" if missing_keys: message += f"; missing key(s): {missing_keys}" if extra_keys: message += f"; extra key(s): {extra_keys}" raise ValueError(f"Node keys mismatch{message}.") children = [node[key] for key in expected_keys] else: # node_type is treespec.type children, metadata, *_ = optree.tree_flatten_one_level( node, none_is_leaf=True, namespace="torch", ) if ( node_type is not deque # ignore mismatch of `maxlen` for deque ) and metadata != treespec._metadata: raise ValueError( f"Node metadata mismatch for node type {treespec.type!r}; " f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch ) for subtree, subspec in zip(children, treespec._children): helper(subspec, subtree, subtrees) subtrees: list[PyTree] = [] helper(self, tree, subtrees) return subtrees def unflatten(self, leaves: Iterable[Any]) -> PyTree: if not isinstance(leaves, (list, tuple)): leaves = list(leaves) if len(leaves) != self.num_leaves: raise ValueError( f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " f"but the spec refers to a pytree that holds {self.num_leaves} " f"items ({self}).", ) if self.is_leaf(): return leaves[0] # Recursively unflatten the children start = 0 end = 0 subtrees = [] for subspec in self._children: end += subspec.num_leaves subtrees.append(subspec.unflatten(leaves[start:end])) start = end assert callable(self._unflatten_func) return self._unflatten_func(self._metadata, subtrees) _LEAF_SPEC = PyTreeSpec((), None, None, (), None) def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: return isinstance(obj, PyTreeSpec) @substitute_in_graph( # type: ignore[arg-type] cxx_pytree.tree_flatten, # We need to disable constant folding here because we want the function to reference the # PyTreeSpec class defined above, not the one in the C++ module. can_constant_fold_through=False, ) def tree_flatten( tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> tuple[list[Any], PyTreeSpec]: def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: if tree_is_leaf(node, is_leaf=is_leaf): leaves.append(node) return _LEAF_SPEC ( children, metadata, entries, unflatten_func, ) = optree.tree_flatten_one_level( node, is_leaf=is_leaf, none_is_leaf=True, namespace="torch", ) # Recursively flatten the children subspecs = tuple(helper(child, leaves) for child in children) return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type] leaves: list[Any] = [] treespec = helper(tree, leaves) return leaves, treespec __all__ += ["tree_flatten"] @substitute_in_graph( # type: ignore[arg-type] cxx_pytree.tree_structure, # We need to disable constant folding here because we want the function to reference the # PyTreeSpec class defined above, not the one in the C++ module. can_constant_fold_through=False, ) def tree_structure( tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTreeSpec: return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value] __all__ += ["tree_structure"] @substitute_in_graph( # type: ignore[arg-type] cxx_pytree.tree_unflatten, # We need to disable constant folding here because we want the function to reference the # PyTreeSpec class defined above, not the one in the C++ module. can_constant_fold_through=False, ) def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree: if not _is_pytreespec_instance(treespec): raise TypeError( f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " f"PyTreeSpec but got item of type {type(treespec)}." ) return treespec.unflatten(leaves) __all__ += ["tree_unflatten"] @substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True) def tree_map( func: Callable[..., Any], tree: PyTree, *rests: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] return treespec.unflatten(map(func, *flat_args)) __all__ += ["tree_map"] @substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True) def tree_map_( func: Callable[..., Any], tree: PyTree, *rests: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable return tree __all__ += ["tree_map_"]