418 lines
15 KiB
Python
418 lines
15 KiB
Python
"""
Python polyfills for torch.utils.pytree
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Literal, TYPE_CHECKING
|
|
from typing_extensions import TypeIs
|
|
|
|
import torch.utils._pytree as python_pytree
|
|
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
|
|
|
|
from ..decorators import substitute_in_graph
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
import builtins
|
|
from collections.abc import Iterable
|
|
from typing_extensions import Self
|
|
|
|
|
|
# Starts empty; the polyfills defined further down extend it with
# ``__all__ += [...]`` statements placed right after their definitions.
__all__: list[str] = []
|
|
|
|
|
|
if python_pytree._cxx_pytree_dynamo_traceable:
|
|
import optree
|
|
import optree._C
|
|
|
|
import torch.utils._cxx_pytree as cxx_pytree
|
|
|
|
if TYPE_CHECKING:
|
|
from torch.utils._cxx_pytree import PyTree
|
|
|
|
@substitute_in_graph(
    optree._C.is_dict_insertion_ordered,
    can_constant_fold_through=True,
)
def _(*args: Any, **kwargs: Any) -> bool:
    """Stand-in registered for ``optree._C.is_dict_insertion_ordered``.

    In namespace 'torch', the dictionary is always traversed in insertion
    order, so the original function returns True.

    Raises:
        ValueError: Always. The body is not meant to be executed because the
            original function is called in the constant fold path.
    """
    message = (
        "Should not be called directly "
        "because the original function will be called in the constant fold path."
    )
    raise ValueError(message)
|
|
|
|
# Register the pure-Python implementation that optree attaches to each of
# these helpers (``__python_implementation__``) as its in-graph substitute.
# The loop variable is pre-bound so the final ``del`` is always valid.
__name = ""
for __name in (
    "is_namedtuple",
    "is_namedtuple_class",
    "is_namedtuple_instance",
    "is_structseq",
    "is_structseq_class",
    "is_structseq_instance",
    "namedtuple_fields",
    "structseq_fields",
):
    __func = getattr(optree, __name)
    __register = substitute_in_graph(__func, can_constant_fold_through=True)
    __register(__func.__python_implementation__)
    # Drop the per-iteration temporaries so they do not linger in the namespace.
    del __func, __register
del __name
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    """Return whether ``tree`` is treated as a leaf.

    Args:
        tree: The object to classify.
        is_leaf: Optional predicate; a truthy result marks ``tree`` as a leaf.

    Returns:
        ``True`` when ``tree`` is ``None``, when ``is_leaf`` accepts it, or
        when its type has no pytree node registration in namespace "torch";
        ``False`` otherwise.
    """
    if tree is None:
        return True
    if is_leaf is not None and is_leaf(tree):
        return True
    # A type without a registration in the "torch" namespace cannot be
    # flattened any further, so it is a leaf.
    registration = optree.register_pytree_node.get(type(tree), namespace="torch")  # type: ignore[attr-defined]
    return registration is None
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
def tree_iter(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
    """Lazily yield the leaves of ``tree``.

    Args:
        tree: The pytree to walk.
        is_leaf: Optional predicate forwarded to ``tree_is_leaf`` and to
            ``optree.tree_flatten_one_level``.

    Yields:
        Every node for which ``tree_is_leaf`` returns ``True``.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        if not tree_is_leaf(current, is_leaf=is_leaf):
            children, *_rest = optree.tree_flatten_one_level(
                current,
                is_leaf=is_leaf,
                none_is_leaf=True,
                namespace="torch",
            )
            # The stack pops from the end, so push the children reversed to
            # visit the first child first.
            pending.extend(reversed(children))
        else:
            yield current


__all__ += ["tree_iter"]
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
    """Collect the leaves of ``tree`` into a list.

    Args:
        tree: The pytree to walk.
        is_leaf: Optional predicate forwarded to ``tree_iter``.

    Returns:
        A list with everything ``tree_iter`` yields, in the same order.
    """
    leaf_iterator = tree_iter(tree, is_leaf=is_leaf)
    return list(leaf_iterator)


__all__ += ["tree_leaves"]
|
|
|
|
class _Asterisk(str):
    """The string ``"*"`` whose ``repr`` omits the surrounding quotes."""

    __slots__ = ()

    def __new__(cls) -> Self:
        # Takes no value argument: every instance is the fixed string "*".
        return str.__new__(cls, "*")

    def __repr__(self) -> str:
        # no quotes
        return "*"


# Keep a single shared instance and remove the class name afterwards.
_asterisk = _Asterisk()
del _Asterisk
|
|
|
|
@dataclass(frozen=True)
class PyTreeSpec:
    """Analog for :class:`optree.PyTreeSpec` in Python.

    A spec is either a leaf (``_type`` is ``None``, no children) or an
    internal node that stores one child spec per child plus the function that
    rebuilds the node from its children. The counters ``num_nodes``,
    ``num_leaves`` and ``num_children`` are derived in ``__post_init__``.
    """

    # Specs of the direct children; must be empty for a leaf.
    _children: tuple[PyTreeSpec, ...]
    # Node type; ``None`` marks a leaf spec.
    _type: builtins.type | None
    # Passed as first argument to ``_unflatten_func``; must be ``None`` for a leaf.
    _metadata: Any
    # Per-node entries; for standard dict types ``flatten_up_to`` uses them as keys.
    _entries: tuple[Any, ...]
    # Called as ``_unflatten_func(metadata, children)``; ``None`` for a leaf.
    _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None

    # The fields below are not constructor arguments; ``__post_init__`` fills them.
    num_nodes: int = field(init=False)
    num_leaves: int = field(init=False)
    num_children: int = field(init=False)
    none_is_leaf: Literal[True] = field(init=False)
    namespace: Literal["torch"] = field(init=False)

    def __post_init__(self) -> None:
        """Check the leaf/node invariants and compute the derived counters."""
        if self._type is None:
            assert len(self._children) == 0
            assert self._metadata is None
            assert self._entries == ()
            assert self._unflatten_func is None
            num_nodes = 1
            num_leaves = 1
            num_children = 0
        else:
            assert callable(self._unflatten_func)
            # ``start=1`` counts this node itself.
            num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
            num_leaves = sum(spec.num_leaves for spec in self._children)
            num_children = len(self._children)

        # The dataclass is frozen, so plain attribute assignment would raise.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)
        object.__setattr__(self, "none_is_leaf", True)
        object.__setattr__(self, "namespace", "torch")

    def __repr__(self) -> str:
        """Render the spec as ``PyTreeSpec(<structure>, NoneIsLeaf, namespace='torch')``."""

        def helper(treespec: PyTreeSpec) -> Any:
            # Leaves are shown as a bare ``*``.
            if treespec.is_leaf():
                assert treespec.type is None
                return _asterisk

            assert treespec.type is not None
            assert callable(treespec._unflatten_func)
            children_representations = [
                helper(subspec) for subspec in treespec._children
            ]
            if (
                treespec.type in BUILTIN_TYPES
                or optree.is_namedtuple_class(treespec.type)
                or optree.is_structseq_class(treespec.type)
            ):
                # Rebuild the container around the child representations; its
                # own repr/str then produces the nested display.
                return treespec._unflatten_func(
                    treespec._metadata,
                    children_representations,
                )
            # A child representation may be a rebuilt container rather than a
            # string (see the branch above), and ``str.join`` rejects
            # non-string items, so stringify every child first.
            rendered_children = ", ".join(
                [str(rep) for rep in children_representations]
            )
            return (
                f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], "
                f"[{rendered_children}])"
            )

        return (
            f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
        )

    def __len__(self) -> int:
        """Return the number of leaves."""
        return self.num_leaves

    @property
    def type(self) -> builtins.type | None:
        """The node type, or ``None`` for a leaf spec."""
        return self._type

    def is_leaf(self) -> bool:
        """Return whether this spec is a leaf.

        Both counters are checked because a childless non-leaf node also has
        ``num_nodes == 1`` but ``num_leaves == 0``.
        """
        return self.num_nodes == 1 and self.num_leaves == 1

    def children(self) -> list[PyTreeSpec]:
        """Return the child specs as a new list."""
        return list(self._children)

    def child(self, index: int) -> PyTreeSpec:
        """Return the child spec at ``index``."""
        return self._children[index]

    def entries(self) -> list[Any]:
        """Return the entries as a new list."""
        return list(self._entries)

    def entry(self, index: int) -> Any:
        """Return the entry at ``index``."""
        return self._entries[index]

    def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
        """Flatten ``tree`` down to the positions of this spec's leaves.

        Args:
            tree: A pytree whose structure must be compatible with this spec.

        Returns:
            The subtrees of ``tree`` found where this spec has leaves.

        Raises:
            ValueError: If a node's type, arity, keys or metadata does not
                match the corresponding spec node.
        """

        def helper(
            treespec: PyTreeSpec,
            node: PyTree,
            subtrees: list[PyTree],
        ) -> None:
            if treespec.is_leaf():
                subtrees.append(node)
                return

            node_type = type(node)
            if treespec.type not in BUILTIN_TYPES:
                # Always require custom node types to match exactly
                if node_type != treespec.type:
                    raise ValueError(
                        f"Type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )

                children, metadata, *_ = optree.tree_flatten_one_level(
                    node,
                    none_is_leaf=True,
                    namespace="torch",
                )
                if len(children) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(children)}.",
                    )
                if metadata != treespec._metadata:
                    raise ValueError(
                        f"Node context mismatch for custom node type {treespec.type!r}.",
                    )
            else:
                # For builtin dictionary types, we allow some flexibility
                # Otherwise, we require exact matches
                both_standard_dict = (
                    treespec.type in STANDARD_DICT_TYPES
                    and node_type in STANDARD_DICT_TYPES
                )
                if not both_standard_dict and node_type != treespec.type:
                    raise ValueError(
                        f"Node type mismatch; "
                        f"expected {treespec.type!r}, but got {node_type!r}.",
                    )
                if len(node) != treespec.num_children:
                    raise ValueError(
                        f"Node arity mismatch; "
                        f"expected {treespec.num_children}, but got {len(node)}.",
                    )

                if both_standard_dict:
                    # dictionary types are compatible with each other
                    expected_keys = treespec.entries()
                    got_key_set = set(node)
                    expected_key_set = set(expected_keys)
                    if got_key_set != expected_key_set:
                        missing_keys = expected_key_set.difference(got_key_set)
                        extra_keys = got_key_set.difference(expected_key_set)
                        message = ""
                        if missing_keys:
                            message += f"; missing key(s): {missing_keys}"
                        if extra_keys:
                            message += f"; extra key(s): {extra_keys}"
                        raise ValueError(f"Node keys mismatch{message}.")
                    # Read the values in the spec's key order, not the node's.
                    children = [node[key] for key in expected_keys]
                else:
                    # node_type is treespec.type
                    children, metadata, *_ = optree.tree_flatten_one_level(
                        node,
                        none_is_leaf=True,
                        namespace="torch",
                    )
                    if (
                        node_type
                        is not deque  # ignore mismatch of `maxlen` for deque
                    ) and metadata != treespec._metadata:
                        raise ValueError(
                            f"Node metadata mismatch for node type {treespec.type!r}; "
                            f"expected {treespec._metadata!r}, but got {metadata!r}.",  # namedtuple type mismatch
                        )

            for subtree, subspec in zip(children, treespec._children):
                helper(subspec, subtree, subtrees)

        subtrees: list[PyTree] = []
        helper(self, tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this structure from ``leaves``.

        Args:
            leaves: The leaf values; anything that is not a list or tuple is
                first materialized into a list.

        Returns:
            The single leaf for a leaf spec, otherwise the result of
            ``_unflatten_func(_metadata, subtrees)``.

        Raises:
            ValueError: If the number of leaves differs from ``num_leaves``.
        """
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        # Recursively unflatten the children; each child consumes the next
        # ``num_leaves`` items of ``leaves``.
        start = 0
        end = 0
        subtrees = []
        for subspec in self._children:
            end += subspec.num_leaves
            subtrees.append(subspec.unflatten(leaves[start:end]))
            start = end

        assert callable(self._unflatten_func)
        return self._unflatten_func(self._metadata, subtrees)
|
|
|
|
# Shared spec instance for leaves: no children, no type, no metadata,
# no entries and no unflatten function.
_LEAF_SPEC = PyTreeSpec(
    _children=(),
    _type=None,
    _metadata=None,
    _entries=(),
    _unflatten_func=None,
)
|
|
|
|
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
    """Return whether ``obj`` is an instance of the ``PyTreeSpec`` class of this module."""
    matches = isinstance(obj, PyTreeSpec)
    return matches
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_flatten,
    # Constant folding is disabled here because the function has to reference
    # the PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_flatten(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], PyTreeSpec]:
    """Flatten ``tree`` into its leaves and a spec describing its structure.

    Args:
        tree: The pytree to flatten.
        is_leaf: Optional predicate forwarded to ``tree_is_leaf`` and to
            ``optree.tree_flatten_one_level``.

    Returns:
        A ``(leaves, treespec)`` pair.
    """

    def build_spec(node: PyTree, collected: list[Any]) -> PyTreeSpec:
        # Leaves are recorded and all share one spec instance.
        if tree_is_leaf(node, is_leaf=is_leaf):
            collected.append(node)
            return _LEAF_SPEC

        one_level = optree.tree_flatten_one_level(
            node,
            is_leaf=is_leaf,
            none_is_leaf=True,
            namespace="torch",
        )
        children, metadata, entries, unflatten_func = one_level

        # Recursively flatten the children
        child_specs = []
        for child in children:
            child_specs.append(build_spec(child, collected))
        return PyTreeSpec(tuple(child_specs), type(node), metadata, entries, unflatten_func)  # type: ignore[arg-type]

    flat_leaves: list[Any] = []
    root_spec = build_spec(tree, flat_leaves)
    return flat_leaves, root_spec


__all__ += ["tree_flatten"]
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_structure,
    # Constant folding is disabled here because the function has to reference
    # the PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_structure(
    tree: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTreeSpec:
    """Return only the spec part of ``tree_flatten(tree, is_leaf=is_leaf)``.

    Args:
        tree: The pytree to describe.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.

    Returns:
        The spec of ``tree``.
    """
    _leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    return treespec  # type: ignore[return-value]


__all__ += ["tree_structure"]
|
|
|
|
@substitute_in_graph(  # type: ignore[arg-type]
    cxx_pytree.tree_unflatten,
    # Constant folding is disabled here because the function has to reference
    # the PyTreeSpec class defined above, not the one in the C++ module.
    can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
    """Rebuild a pytree from ``leaves`` using ``treespec``.

    Args:
        leaves: The leaf values, handed to ``treespec.unflatten``.
        treespec: The structure to rebuild.

    Returns:
        The result of ``treespec.unflatten(leaves)``.

    Raises:
        TypeError: If ``treespec`` is not an instance of ``PyTreeSpec``.
    """
    if _is_pytreespec_instance(treespec):
        return treespec.unflatten(leaves)

    raise TypeError(
        f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
        f"PyTreeSpec but got item of type {type(treespec)}."
    )


__all__ += ["tree_unflatten"]
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Apply ``func`` leaf-wise and rebuild a pytree shaped like ``tree``.

    Args:
        func: Called with one leaf of ``tree`` followed by the corresponding
            subtree of each pytree in ``rests``.
        tree: The pytree that determines the structure.
        *rests: Further pytrees, flattened up to the structure of ``tree``.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.

    Returns:
        A pytree with the structure of ``tree`` holding the results of ``func``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    columns = [leaves]
    for other in rests:
        columns.append(treespec.flatten_up_to(other))
    return treespec.unflatten(map(func, *columns))


__all__ += ["tree_map"]
|
|
|
|
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Call ``func`` leaf-wise, discard the results and return ``tree``.

    Args:
        func: Called with one leaf of ``tree`` followed by the corresponding
            subtree of each pytree in ``rests``.
        tree: The pytree that determines the structure.
        *rests: Further pytrees, flattened up to the structure of ``tree``.
        is_leaf: Optional predicate forwarded to ``tree_flatten``.

    Returns:
        The ``tree`` argument itself.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    columns = [leaves]
    for other in rests:
        columns.append(treespec.flatten_up_to(other))
    # A zero-length deque consumes and exhausts the iterable without storing results.
    deque(map(func, *columns), maxlen=0)
    return tree


__all__ += ["tree_map_"]
|