team-10/env/Lib/site-packages/torch/export/_swap.py
2025-08-02 07:34:44 +02:00

438 lines
17 KiB
Python

import logging
import operator
import types
from collections import defaultdict
from typing import Optional
import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch.export.exported_program import (
ConstantArgument,
ExportedProgram,
ModuleCallSignature,
)
from torch.fx.passes.tools_common import legalize_graph, NodeList
from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule
# Module-level logger; the helpers below emit their graph dumps at DEBUG level.
log = logging.getLogger(__name__)
def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
    """
    Return the nodes that consume the ``operator.getitem`` users of `node`.

    Every user of `node` must either be the graph's ``output`` node (which is
    skipped) or a ``call_function`` node targeting ``operator.getitem``. The
    result is the union of the users of those getitem nodes, i.e. the nodes
    sitting two hops downstream of `node`.

    Raises:
        AssertionError: if a non-output user of `node` is not a getitem call.
    """
    downstream: set[torch.fx.Node] = set()
    for direct_user in tuple(node.users):
        if direct_user.op != "output":
            is_getitem = (
                direct_user.op == "call_function"
                and direct_user.target == operator.getitem
            )
            assert (
                is_getitem
            ), f"Expected getitem node as user for {node}, instead got {direct_user}"
            # Iterating the `users` mapping yields the user nodes themselves.
            downstream.update(direct_user.users)
    return downstream
def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
    """
    We want to try to remove extraneous pytree flatten/unflatten calls between modules
    calls. Instead of having the following:

    graph():
        ...
        %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
        %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {})
        %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
        %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {})
        %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {})
        %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {})
        %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {})
        %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {})
        ...

    We could do the following, if we know that all the outputs of `foo` feed into `bar`:

    graph():
        ...
        %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {})
        %bar : [num_users=1] = call_module[target=bar](args = (%foo,), kwargs = {})
        ...

    Currently this optimization only works for the case where all of the outputs
    of `foo` go directly into `bar`, and `bar` has no other inputs, i.e. `bar`
    is called with exactly one positional argument and no keyword arguments.

    This function only rewrites the `args` of the next module node. The
    bypassed flatten/unflatten/getitem nodes are not erased here; they are
    simply left without users.

    Args:
        curr_module_node: the ``call_module`` node (`foo` above) whose outputs
            may be forwarded directly to the following module call.

    Raises:
        AssertionError: if `curr_module_node` does not have exactly one user,
            if that user is not a ``tree_flatten_spec`` call, or if a
            non-output user of the flatten/unflatten node is not a getitem.
    """  # noqa: B950
    log.debug("Trying to remove pytrees for module call %s", curr_module_node)

    curr_module_users = list(curr_module_node.users.keys())
    assert (
        len(curr_module_users) == 1
    ), f"Expected only one user for module node, instead got {list(curr_module_users)}"
    flatten_node = curr_module_users[0]
    assert (
        flatten_node.op == "call_function"
        and flatten_node.target == fx_pytree.tree_flatten_spec
    )

    # All getitems of the flattened outputs must feed one single consumer.
    flatten_getitem_users = _get_getitem_users(flatten_node)
    if len(flatten_getitem_users) != 1:
        log.debug(
            "More than one user found for flatten node, %s: %s. "
            "Unable to fuse it with another unflatten call.",
            flatten_node,
            flatten_getitem_users,
        )
        return

    unflatten_node = next(iter(flatten_getitem_users))
    if not (
        unflatten_node.op == "call_function"
        and unflatten_node.target == pytree.tree_unflatten
    ):
        log.debug(
            "Flatten node %s's user is not a pytree.tree_unflatten. "
            "Instead it is: %s. Passing...",
            flatten_node,
            unflatten_node,
        )
        return

    # The unflatten's flat inputs must be exactly getitem(flatten, 0),
    # getitem(flatten, 1), ... in that order.
    for i, arg in enumerate(unflatten_node.args[0]):  # type: ignore[union-attr,arg-type]
        if arg not in flatten_node.users:
            log.debug(
                "Module %s's outputs are not all directly used as inputs to "
                "the subsequent module. Unable to fuse the connecting "
                "flatten/unflatten. The inputs to the subsequent module are: %s. ",
                curr_module_node,
                unflatten_node.args[0],
            )
            return

        if not (
            arg.op == "call_function"
            and arg.target == operator.getitem
            and arg.args[1] == i
        ):
            log.debug(
                "Module %s's outputs are not all directly used in the same "
                "order as outputted. Unable to fuse the connecting "
                "flatten/unflatten. The inputs to the "
                "subsequent module are: %s. ",
                curr_module_node,
                unflatten_node.args[0],
            )
            return

    # Unflatten has two levels of getitem, because it gets the args and kwargs
    unflatten_getitem_getitem_users = set()
    unflatten_getitem_users = _get_getitem_users(unflatten_node)
    for unflatten_getitem_user in unflatten_getitem_users:
        unflatten_getitem_getitem_users.update(
            list(unflatten_getitem_user.users.keys())
        )

    if len(unflatten_getitem_getitem_users) != 1:
        log.debug(
            "More than one user found for unflatten node, %s: %s. "
            "Unable to fuse it with another flatten call.",
            unflatten_node,
            unflatten_getitem_getitem_users,
        )
        return

    next_module_node = next(iter(unflatten_getitem_getitem_users))
    if not (next_module_node.op == "call_module"):
        log.debug(
            "Unflatten node %s's user is not a call_module. "
            "Instead it is: %s. Passing...",
            unflatten_node,
            next_module_node,
        )
        return

    # Replacing the arguments with the single value `curr_module_node` is only
    # valid when the next module takes exactly one positional input and no
    # keyword inputs; otherwise the rewrite would change its call arity (e.g.
    # `bar(a, b)` fed by a `foo` returning `(a, b)` would become `bar((a, b))`).
    if len(next_module_node.args) != 1 or next_module_node.kwargs:
        log.debug(
            "Module %s is not called with a single positional argument "
            "(args: %s, kwargs: %s). Unable to fuse the connecting "
            "flatten/unflatten.",
            next_module_node,
            next_module_node.args,
            next_module_node.kwargs,
        )
        return

    # Directly put the outputs of the current module into the next module
    next_module_node.args = (curr_module_node,)
def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
    """
    Strip pytree flatten/unflatten calls from `gm` that serve no purpose.

    Optimizations attempted:
        1. Remove pytree flatten/unflatten calls between modules
        2. TODO: Remove module's in_spec + initial unflatten call
        3. TODO: Remove module's out_spec + final flatten call

    The graph is modified in place; dead code elimination runs at the end so
    that nodes left without users by step 1 are dropped.
    """
    # Step 1 only rewires node arguments, so snapshotting the call_module nodes
    # up front visits the same nodes as walking the live node list.
    module_calls = [n for n in gm.graph.nodes if n.op == "call_module"]
    for module_call in module_calls:
        _try_remove_connecting_pytrees(module_call)

    gm.graph.eliminate_dead_code()
def _construct_inputs(
    gm: torch.fx.GraphModule,
    signature: ModuleCallSignature,
    node_name_map: dict[str, torch.fx.Node],
) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
    """
    Build the positional and keyword argument nodes for a swapped module call.

    The flat inputs listed in `signature.inputs` are resolved to graph nodes
    through `node_name_map`, handed to `_generate_unflatten` together with
    `signature.in_spec`, and the resulting node is then indexed with
    ``operator.getitem`` to recover each positional and keyword argument.

    Args:
        gm: graph module whose graph receives the new nodes.
        signature: call signature of the module being swapped in; its
            `in_spec` must have exactly two children (args and kwargs).
        node_name_map: mapping from node name to node, used to look up inputs.

    Returns:
        A ``(args_nodes, kwargs_nodes)`` pair: a list with one getitem node per
        positional argument, and a dict with one getitem node per keyword.
    """
    flat_inputs: list[Optional[torch.fx.Node]] = []
    for spec_input in signature.inputs:
        # A `None` slot is used both for None-valued constants (which are not
        # passed as graph inputs) and for inputs that have no node in the map
        # (unused inputs). The isinstance test short-circuits before `.name`.
        is_none_constant = (
            isinstance(spec_input, ConstantArgument) and spec_input.value is None
        )
        if is_none_constant or spec_input.name not in node_name_map:
            flat_inputs.append(None)
        else:
            flat_inputs.append(node_name_map[spec_input.name])

    # Insert unflatten call
    from .unflatten import _generate_unflatten

    unflattened = _generate_unflatten(gm, flat_inputs, signature.in_spec)

    assert signature.in_spec.num_children == 2

    # Child 0 of the in_spec describes the positional arguments.
    positional_spec = signature.in_spec.children_specs[0]
    assert positional_spec.context is None
    positional = gm.graph.call_function(operator.getitem, (unflattened, 0))
    args_nodes = []
    for index in range(positional_spec.num_children):
        args_nodes.append(
            gm.graph.call_function(operator.getitem, (positional, index))
        )

    # Child 1 describes the keyword arguments; its context supplies the keys.
    keyword_spec = signature.in_spec.children_specs[1]
    assert keyword_spec.context is not None
    keyword = gm.graph.call_function(operator.getitem, (unflattened, 1))
    kwargs_nodes = {}
    for key in keyword_spec.context:
        kwargs_nodes[key] = gm.graph.call_function(operator.getitem, (keyword, key))

    return args_nodes, kwargs_nodes
def _insert_call_module(
    gm: torch.fx.GraphModule,
    args_nodes: list[torch.fx.Node],
    kwargs_nodes: dict[str, torch.fx.Node],
    module_to_swap: torch.nn.Module,
    name: str,
) -> torch.fx.Node:
    """
    Attach `module_to_swap` to `gm` under `name` and emit a call to it.

    Args:
        gm: graph module that receives both the attribute and the new node.
        args_nodes: positional argument nodes for the call.
        kwargs_nodes: keyword argument nodes for the call.
        module_to_swap: module assigned onto `gm` via `_assign_attr`.
        name: attribute path used for the assignment and as the call target.

    Returns:
        The newly created ``call_module`` node.
    """
    from .unflatten import _assign_attr, _AttrKind

    _assign_attr(module_to_swap, gm, name, _AttrKind.MODULE)
    positional = tuple(args_nodes)
    return gm.graph.call_module(name, args=positional, kwargs=kwargs_nodes)  # type: ignore[arg-type]
def _deconstruct_outputs(
    gm: torch.fx.GraphModule,
    signature: ModuleCallSignature,
    module_node: torch.fx.Node,
    node_name_map: dict[str, torch.fx.Node],
    orig_outputs: tuple[torch.fx.Node, ...],
) -> None:
    """
    Redirect users of the original partition outputs to the swapped module.

    A flatten node is generated for `module_node` using `signature.out_spec`.
    The i-th original output is then replaced everywhere by the i-th item of
    that flatten node, and `node_name_map` is updated in place so that the
    original output's name resolves to the replacement node.
    """
    from .unflatten import _generate_flatten_spec

    flattened = _generate_flatten_spec(gm, module_node, signature.out_spec)

    for position, replaced in enumerate(orig_outputs):
        # Indexing a Proxy records the getitem access as a node.
        item_proxy = torch.fx.Proxy(flattened)[position]  # type: ignore[index]
        item_node = item_proxy.node
        replaced.replace_all_uses_with(item_node, propagate_meta=True)
        node_name_map[replaced.name] = item_node
def _swap_module_helper(
    gm: torch.fx.GraphModule,
    modules_to_swap: dict[str, torch.nn.Module],
    module_call_graph: dict[str, ModuleCallSignature],
) -> torch.fx.GraphModule:
    """
    Replace the traced subgraph of each module in `modules_to_swap` with a
    ``call_module`` node that invokes the supplied module.

    Nodes are grouped by the first path in their ``nn_module_stack`` metadata
    that is a key of `modules_to_swap`. Each group is then replaced by an
    unflatten + call_module + flatten sequence, and finally the pytree calls
    connecting adjacent swapped modules are removed where possible.

    Args:
        gm: graph module to rewrite in place.
        modules_to_swap: mapping from module path to the module to call.
        module_call_graph: mapping from module path to its call signature; it
            must contain an entry for every path that ends up with a partition.

    Returns:
        `gm`, recompiled after the rewrite.
    """
    log.debug("Starting graph:")
    log.debug(gm.graph)

    legalize_graph(gm)

    partitions: dict[str, NodeList] = defaultdict(list)

    node_name_map: dict[str, torch.fx.Node] = {}
    for graph_node in gm.graph.nodes:
        node_name_map[graph_node.name] = graph_node

    # TODO: Handle the duplicate module case
    for graph_node in gm.graph.nodes:
        module_stack = graph_node.meta.get("nn_module_stack")
        if not module_stack:
            continue
        # First entry of the stack that names a module we were asked to swap.
        owner = next(
            (path for path, _ in module_stack.values() if path in modules_to_swap),
            None,
        )
        if owner is not None:
            partitions[owner].append(graph_node)

    # For every partition, e.g. the nodes traced from submodule "foo":
    #   1. fuse_as_graphmodule reports the partition's original output nodes
    #      (the fused submodule itself is only logged);
    #   2. _construct_inputs unflattens the flat inputs into the args/kwargs
    #      expected by the swapped module;
    #   3. _insert_call_module adds a call_module node for the swapped module;
    #   4. _deconstruct_outputs flattens the module's outputs and points the
    #      users of the original outputs at them;
    #   5. erase_nodes drops the partition's original nodes.
    for fqn, partition_nodes in partitions.items():
        flat_name = fqn.replace(".", "_")
        sub_gm, _, orig_outputs = fuse_as_graphmodule(
            gm, partition_nodes, f"fused_{flat_name}"
        )

        log.debug("Fused subgraph nodes:")
        log.debug(sub_gm.graph)

        signature: ModuleCallSignature = module_call_graph[fqn]

        args_nodes, kwargs_nodes = _construct_inputs(gm, signature, node_name_map)
        swapped_call = _insert_call_module(
            gm, args_nodes, kwargs_nodes, modules_to_swap[fqn], fqn
        )
        _deconstruct_outputs(gm, signature, swapped_call, node_name_map, orig_outputs)

        erase_nodes(gm, partition_nodes)

        log.debug("Swapped graph:")
        log.debug(gm.graph)

    legalize_graph(gm)

    log.debug("Before removing extraneous pytrees:")
    log.debug(gm.graph)

    _remove_extraneous_pytrees(gm)

    log.debug("After removing extraneous pytrees:")
    log.debug(gm.graph)

    gm.recompile()
    return gm
def _fix_input_output_signature(
gm: torch.fx.GraphModule, signature: ModuleCallSignature
) -> None:
"""
Given the unlifted module from calling ep.module(), we want to remove the
pytree processing from the graph module's PyTreeCodeGen and instead make it
nodes inside of the graph. This allows us to do some optimizations, like
remove these pytree calls if it is unnecessary, and makes the PyTree part
more obvious to graph passes.
"""
from torch.export.unflatten import _generate_flatten, _generate_unflatten
# Remove the registered pytree codegen because we will take care of it
# through inserting pytree nodes into the graph
gm.graph._codegen = torch.fx.graph.CodeGen()
old_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
new_placeholders = []
forward_arg_names = signature.forward_arg_names
if forward_arg_names is None:
forward_arg_names = []
assert signature.in_spec.num_children == 2
arg_spec = signature.in_spec.children_specs[0]
kwarg_spec = signature.in_spec.children_specs[1]
assert arg_spec.type == tuple
assert kwarg_spec.type == dict
for i in range(arg_spec.num_children):
forward_arg_names.append(f"arg_{i}")
forward_arg_names.extend(kwarg_spec.context)
for arg in forward_arg_names:
with gm.graph.inserting_before(old_placeholders[0]):
new_placeholders.append(gm.graph.placeholder(arg))
# Insert flatten call for the inputs
with gm.graph.inserting_before(old_placeholders[0]):
flat_node = _generate_flatten(gm, tuple(new_placeholders))
for i, old_placeholder in enumerate(old_placeholders):
old_placeholder.op = "call_function"
old_placeholder.target = operator.getitem
old_placeholder.args = (flat_node, i)
# Insert unflatten call for the outputs
output_node = next(node for node in gm.graph.nodes if node.op == "output")
with gm.graph.inserting_before(output_node):
unflat = _generate_unflatten(gm, output_node.args[0], signature.out_spec)
output_node.args = (unflat,)
gm.recompile()
def _swap_modules(
    ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
) -> torch.fx.GraphModule:
    """
    Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
    previously traced modules with new eager modules specified. Returns a
    fx.GraphModule with a custom forward function.

    Args:
        ep (ExportedProgram): Exported program to modify
        modules_to_swap (Dict[str, torch.nn.Module]): Mapping from module fqn to
            eager module to swap with. The specified module fqn should have also
            been specified in the `preserve_module_call_signature` argument to
            torch.export so that we know how to restore the calling convention
            to this argument.

    Returns:
        The graph module obtained from `ep.module()`, rewritten so that pytree
        handling is explicit in the graph and the requested modules are invoked
        through ``call_module`` nodes.

    Raises:
        AssertionError: if `ep.module()` is not a `torch.fx.GraphModule` or the
            first entry of `ep.module_call_graph` carries no signature.
    """
    # Only entries that carry a signature can be swapped.
    module_call_graph = {
        entry.fqn: entry.signature for entry in ep.module_call_graph if entry.signature
    }

    gm = ep.module()
    # Check the type before touching the module, so a failure leaves it unmodified.
    assert isinstance(gm, torch.fx.GraphModule)
    gm.validate_inputs = False  # type: ignore[assignment]
    gm.graph.eliminate_dead_code()

    # The first module_call_graph entry is used as the signature of `gm` itself.
    root_signature = ep.module_call_graph[0].signature
    assert (
        root_signature is not None
    ), "Expected the first module_call_graph entry to have a signature"
    _fix_input_output_signature(gm, root_signature)

    gm.module_call_graph = ep.module_call_graph
    # Bind the class-level train/eval onto the instance.
    # NOTE(review): presumably this undoes instance-level overrides installed
    # by ep.module() - confirm.
    gm.train = types.MethodType(type(gm).train, gm)  # type: ignore[assignment]
    gm.eval = types.MethodType(type(gm).eval, gm)  # type: ignore[assignment]

    return _swap_module_helper(gm, modules_to_swap, module_call_graph)