team-10/env/Lib/site-packages/torch/_dynamo/polyfills/pytree.py
2025-08-02 07:34:44 +02:00

418 lines
15 KiB
Python

"""
Python polyfills for torch.utils.pytree
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
from ..decorators import substitute_in_graph
if TYPE_CHECKING:
import builtins
from collections.abc import Iterable
from typing_extensions import Self
__all__: list[str] = []
if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import torch.utils._cxx_pytree as cxx_pytree
if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@substitute_in_graph(
    optree._C.is_dict_insertion_ordered,
    can_constant_fold_through=True,
)
def _(*args: Any, **kwargs: Any) -> bool:
    """Placeholder substitute for ``optree._C.is_dict_insertion_ordered``.

    In namespace 'torch' dictionaries are always traversed in insertion
    order, so the original function returns True.  Because the substitution
    is registered with ``can_constant_fold_through=True``, the original is
    expected to be evaluated on the constant fold path and this body is not
    meant to run.

    Raises:
        ValueError: always, if the placeholder is invoked directly.
    """
    message = (
        "Should not be called directly "
        "because the original function will be called in the constant fold path."
    )
    raise ValueError(message)
# For each optree helper listed below, register the callable it exposes as
# ``__python_implementation__`` as its in-graph substitute.
# ``__name`` is pre-bound so the trailing ``del __name`` is always valid.
__name = ""
for __name in (
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
    "namedtuple_fields",
    "structseq_fields",
):
    __func = getattr(optree, __name)
    __substitute = substitute_in_graph(__func, can_constant_fold_through=True)
    __substitute(__func.__python_implementation__)
    # Do not leak the loop temporaries into the module namespace.
    del __func, __substitute
del __name
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    """Report whether ``tree`` should be treated as a leaf.

    Args:
        tree: The object to classify.
        is_leaf: Optional predicate; a truthy result marks ``tree`` as a leaf.

    Returns:
        True when ``tree`` is None, when ``is_leaf`` accepts it, or when its
        type has no node registration in the "torch" namespace; False otherwise.
    """
    if tree is None:
        return True
    if is_leaf is not None and is_leaf(tree):
        return True
    # A type without a registration in the "torch" namespace cannot be flattened.
    registration = optree.register_pytree_node.get(type(tree), namespace="torch")  # type: ignore[attr-defined]
    return registration is None
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
def tree_iter(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
    """Lazily yield the leaves of ``tree`` using an explicit stack.

    Nodes are expanded one level at a time with
    ``optree.tree_flatten_one_level``; children are visited in the order that
    call reports them, depth first.

    Args:
        tree: The pytree to walk.
        is_leaf: Optional predicate forwarded to ``tree_is_leaf`` and optree.

    Yields:
        Each leaf encountered during the walk.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        if not tree_is_leaf(current, is_leaf=is_leaf):
            child_nodes, *_ = optree.tree_flatten_one_level(
                current,
                is_leaf=is_leaf,
                none_is_leaf=True,
                namespace="torch",
            )
            # Push in reverse so the first child is the next one popped.
            pending.extend(reversed(child_nodes))
            continue
        yield current


__all__ += ["tree_iter"]
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
    """Collect every leaf produced by ``tree_iter`` into a list.

    Args:
        tree: The pytree to walk.
        is_leaf: Optional predicate forwarded to ``tree_iter``.

    Returns:
        The leaves, in the order ``tree_iter`` yields them.
    """
    collected: list[Any] = []
    collected.extend(tree_iter(tree, is_leaf=is_leaf))
    return collected


__all__ += ["tree_leaves"]
class _Asterisk(str):
    """The string ``"*"`` whose ``repr`` is printed without quotes."""

    __slots__ = ()

    def __new__(cls) -> Self:
        # The value is fixed; callers pass no arguments.
        return str.__new__(cls, "*")

    def __repr__(self) -> str:
        # Plain asterisk, deliberately without the usual quoting.
        return "*"


# Only the single shared instance is kept; the class name is dropped afterwards.
_asterisk = _Asterisk()
del _Asterisk
@dataclass(frozen=True)
class PyTreeSpec:
    """Analog for :class:`optree.PyTreeSpec` in Python.

    A spec is either a leaf (``_type is None``, no children, no unflatten
    function) or an internal node that stores its container type, metadata,
    entries and the callable used to rebuild the container from its children.
    """

    # Specs of the direct children; empty for a leaf.
    _children: tuple[PyTreeSpec, ...]
    # Container type of this node; ``None`` marks a leaf.
    _type: builtins.type | None
    # Passed back to ``_unflatten_func`` and compared in ``flatten_up_to``.
    _metadata: Any
    # Per-child entries; for standard dict types ``flatten_up_to`` uses them as keys.
    _entries: tuple[Any, ...]
    # Callable ``(metadata, children) -> node``; ``None`` for a leaf.
    _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None

    # The fields below are derived in ``__post_init__``.
    num_nodes: int = field(init=False)
    num_leaves: int = field(init=False)
    num_children: int = field(init=False)
    none_is_leaf: Literal[True] = field(init=False)
    namespace: Literal["torch"] = field(init=False)

    def __post_init__(self) -> None:
        """Check leaf/node invariants and fill in the derived counters."""
        if self._type is None:
            assert len(self._children) == 0
            assert self._metadata is None
            assert self._entries == ()
            assert self._unflatten_func is None
            num_nodes = 1
            num_leaves = 1
            num_children = 0
        else:
            assert callable(self._unflatten_func)
            num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
            num_leaves = sum(spec.num_leaves for spec in self._children)
            num_children = len(self._children)

        # The dataclass is frozen, so bypass its ``__setattr__``.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)
        object.__setattr__(self, "none_is_leaf", True)
        object.__setattr__(self, "namespace", "torch")

    def __repr__(self) -> str:
        """Render the spec as ``PyTreeSpec(<structure>, NoneIsLeaf, namespace=...)``."""

        def helper(treespec: PyTreeSpec) -> Any:
            # Returns either a ``str`` (leaf marker or custom-node rendering) or,
            # for built-in/namedtuple/structseq nodes, the container rebuilt from
            # the child renderings, whose own ``repr`` is printed later.
            if treespec.is_leaf():
                assert treespec.type is None
                return _asterisk

            assert treespec.type is not None
            assert callable(treespec._unflatten_func)
            children_representations = [
                helper(subspec) for subspec in treespec._children
            ]
            if (
                treespec.type in BUILTIN_TYPES
                or optree.is_namedtuple_class(treespec.type)
                or optree.is_structseq_class(treespec.type)
            ):
                return treespec._unflatten_func(
                    treespec._metadata,
                    children_representations,
                )
            # A child rendering may be a rebuilt container rather than a string;
            # ``str.join`` rejects non-strings, so render those with ``repr``.
            rendered_children = ", ".join(
                rep if isinstance(rep, str) else repr(rep)
                for rep in children_representations
            )
            return (
                f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
                f"[{rendered_children}])"
            )

        return (
            f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
        )

    def __len__(self) -> int:
        """Return the number of leaves covered by this spec."""
        return self.num_leaves

    @property
    def type(self) -> builtins.type | None:
        """Container type of this node, or ``None`` for a leaf."""
        return self._type

    def is_leaf(self) -> bool:
        """Return True when the spec consists of exactly one node that is a leaf."""
        return self.num_nodes == 1 and self.num_leaves == 1

    def children(self) -> list[PyTreeSpec]:
        """Return the child specs as a new list."""
        return list(self._children)

    def child(self, index: int) -> PyTreeSpec:
        """Return the child spec at ``index``."""
        return self._children[index]

    def entries(self) -> list[Any]:
        """Return the entries as a new list."""
        return list(self._entries)

    def entry(self, index: int) -> Any:
        """Return the entry at ``index``."""
        return self._entries[index]

    def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
        """Flatten ``tree`` only as deep as this spec describes.

        Wherever this spec has a leaf, the corresponding subtree of ``tree``
        is returned as-is.

        Args:
            tree: A pytree whose structure must be compatible with this spec.

        Returns:
            The subtrees of ``tree`` located at this spec's leaf positions.

        Raises:
            ValueError: on a node type, arity, dict key or metadata mismatch.
        """

        def helper(
            treespec: PyTreeSpec,
            node: PyTree,
            subtrees: list[PyTree],
        ) -> None:
            if treespec.is_leaf():
                subtrees.append(node)
                return

            node_type = type(node)
            if treespec.type not in BUILTIN_TYPES:
                # Always require custom node types to match exactly
                if node_type != treespec.type:
                    raise ValueError(
                        f"Type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )

                children, metadata, *_ = optree.tree_flatten_one_level(
                    node,
                    none_is_leaf=True,
                    namespace="torch",
                )
                if len(children) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(children)}.",
                    )
                if metadata != treespec._metadata:
                    raise ValueError(
                        f"Node context mismatch for custom node type {treespec.type!r}.",
                    )
            else:
                # For builtin dictionary types, we allow some flexibility
                # Otherwise, we require exact matches
                both_standard_dict = (
                    treespec.type in STANDARD_DICT_TYPES
                    and node_type in STANDARD_DICT_TYPES
                )
                if not both_standard_dict and node_type != treespec.type:
                    raise ValueError(
                        f"Node type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                if len(node) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(node)}.",
                    )

                if both_standard_dict:
                    # dictionary types are compatible with each other
                    expected_keys = treespec.entries()
                    got_key_set = set(node)
                    expected_key_set = set(expected_keys)
                    if got_key_set != expected_key_set:
                        missing_keys = expected_key_set.difference(got_key_set)
                        extra_keys = got_key_set.difference(expected_key_set)
                        message = ""
                        if missing_keys:
                            message += f"; missing key(s): {missing_keys}"
                        if extra_keys:
                            message += f"; extra key(s): {extra_keys}"
                        raise ValueError(f"Node keys mismatch{message}.")
                    # Follow the spec's key order, not the node's own order.
                    children = [node[key] for key in expected_keys]
                else:
                    # node_type is treespec.type
                    children, metadata, *_ = optree.tree_flatten_one_level(
                        node,
                        none_is_leaf=True,
                        namespace="torch",
                    )
                    if (
                        node_type
                        is not deque  # ignore mismatch of `maxlen` for deque
                    ) and metadata != treespec._metadata:
                        raise ValueError(
                            f"Node metadata mismatch for node type {treespec.type!r}; "
                            f"expected {treespec._metadata!r}, but got {metadata!r}.",  # namedtuple type mismatch
                        )

            for subtree, subspec in zip(children, treespec._children):
                helper(subspec, subtree, subtrees)

        subtrees: list[PyTree] = []
        helper(self, tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this structure from ``leaves``.

        Args:
            leaves: Exactly ``num_leaves`` items; any iterable that is not a
                list or tuple is materialized into a list first.

        Returns:
            The single leaf for a leaf spec, otherwise the node produced by
            ``_unflatten_func`` from the recursively rebuilt children.

        Raises:
            ValueError: if the number of leaves differs from ``num_leaves``.
        """
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        # Recursively unflatten the children
        start = 0
        end = 0
        subtrees = []
        for subspec in self._children:
            # Each child consumes the next ``subspec.num_leaves`` leaves.
            end += subspec.num_leaves
            subtrees.append(subspec.unflatten(leaves[start:end]))
            start = end

        assert callable(self._unflatten_func)
        return self._unflatten_func(self._metadata, subtrees)
# Shared spec used for every leaf: no children, no type, no metadata,
# no entries and no unflatten function.
_LEAF_SPEC = PyTreeSpec(
    _children=(),
    _type=None,
    _metadata=None,
    _entries=(),
    _unflatten_func=None,
)


def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
    """Return True when ``obj`` is an instance of the ``PyTreeSpec`` class of this module."""
    return isinstance(obj, PyTreeSpec)
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_flatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
    """Flatten ``tree`` into its leaves and a ``PyTreeSpec`` describing its structure.

    Args:
        tree: The pytree to flatten.
        is_leaf: Optional predicate forwarded to ``tree_is_leaf`` and optree.

    Returns:
        A ``(leaves, treespec)`` pair; leaves appear in the order they are visited.
    """

    def build_spec(node: PyTree, sink: list[Any]) -> PyTreeSpec:
        # Leaves are recorded in ``sink`` and all share the same leaf spec.
        if tree_is_leaf(node, is_leaf=is_leaf):
            sink.append(node)
            return _LEAF_SPEC

        one_level = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=True,
            namespace="torch",
        )
        child_nodes, node_metadata, node_entries, rebuild = one_level

        # Recursively flatten the children
        child_specs = tuple([build_spec(child, sink) for child in child_nodes])
        return PyTreeSpec(child_specs, type(node), node_metadata, node_entries, rebuild)  # type: ignore[arg-type]

    flat_leaves: list[Any] = []
    structure = build_spec(tree, flat_leaves)
    return flat_leaves, structure


__all__ += ["tree_flatten"]
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_structure,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_structure(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTreeSpec:
    """Return only the ``PyTreeSpec`` that ``tree_flatten`` computes for ``tree``.

    Args:
        tree: The pytree to inspect.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.
    """
    _, structure = tree_flatten(tree, is_leaf=is_leaf)
    return structure  # type: ignore[return-value]


__all__ += ["tree_structure"]
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_unflatten,
    # We need to disable constant folding here because we want the function to reference the
    # PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
    """Rebuild a pytree from ``leaves`` using ``treespec``.

    Args:
        leaves: The leaves to place into the structure.
        treespec: The structure, as an instance of this module's ``PyTreeSpec``.

    Returns:
        The result of ``treespec.unflatten(leaves)``.

    Raises:
        TypeError: if ``treespec`` is not a ``PyTreeSpec`` instance.
    """
    if _is_pytreespec_instance(treespec):
        return treespec.unflatten(leaves)
    raise TypeError(
        f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
        f"PyTreeSpec but got item of type {type(treespec)}."
    )


__all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Apply ``func`` leaf-wise and rebuild a tree with the structure of ``tree``.

    Args:
        func: Called with one leaf of ``tree`` plus the matching subtree of
            each tree in ``rests``.
        tree: The pytree that defines the structure.
        *rests: Further pytrees, each flattened up to the structure of ``tree``.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.

    Returns:
        A new pytree holding the results of ``func``.
    """
    first_leaves, structure = tree_flatten(tree, is_leaf=is_leaf)
    other_leaves = [structure.flatten_up_to(rest) for rest in rests]
    mapped = map(func, first_leaves, *other_leaves)
    return structure.unflatten(mapped)


__all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Call ``func`` leaf-wise for its side effects and return ``tree`` itself.

    Args:
        func: Called with one leaf of ``tree`` plus the matching subtree of
            each tree in ``rests``; its return value is discarded.
        tree: The pytree that defines the structure.
        *rests: Further pytrees, each flattened up to the structure of ``tree``.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.

    Returns:
        The original ``tree`` object.
    """
    first_leaves, structure = tree_flatten(tree, is_leaf=is_leaf)
    other_leaves = [structure.flatten_up_to(rest) for rest in rests]
    # Results are intentionally dropped; only the calls matter.
    for call_args in zip(first_leaves, *other_leaves):
        func(*call_args)
    return tree


__all__ += ["tree_map_"]