"""
|
|
This module implements graph deduplication functionality for TorchDynamo's optimization pipeline.
|
|
Graph deduplication identifies identical subgraphs in the computational graph and merges them
|
|
to reduce redundancy and improve performance. The process involves analyzing regions of the graph,
|
|
identifying structurally equivalent regions, and replacing them with a single shared implementation.
|
|
This optimization is particularly effective for models with repeated patterns or similar computational
|
|
structures across different parts of the network.
|
|
"""

import logging
import operator
from collections.abc import Iterable
from typing import Any

import torch.fx
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.utils._pytree import tree_flatten

from .graph_region_tracker import Node, Region


log = logging.getLogger(__name__)


def apply_graph_deduplication(output_graph) -> dict[Node, Node]:  # type: ignore[no-untyped-def]
    """
    This is the main entry point for applying the graph deduplication pass.
    Deduplication occurs in two phases:

    1. Subgraph creation:
       Subgraph creation works by taking one representative region from each
       region group and creating a subgraph from it, which will then be used
       to replace all regions in the group. This is implemented by first
       copying all nodes of the region to the new subgraph and then finding
       all inputs which are not within the region and creating placeholders
       for them. For the outputs, all regions in a region group need to be
       scanned to ensure the largest set of outputs is found, and then an
       output node is created which returns a tuple of all outputs.

    2. Graph replacement:
       To replace each region with the extracted subgraph, the node index in
       the region and the argument index within the node's flattened args and
       kwargs are recorded once during subgraph creation. This allows us to
       determine which nodes (external to the region) are passed as inputs to
       the subgraph, and in which order. For the outputs, a getitem node is
       created for each output, and every node in the region with external
       users is replaced by the corresponding getitem node. Finally, all
       original nodes are erased (there should be no remaining uses of them
       in the graph).

    The deduplication mutates the output_graph argument in place.

    Returns a mapping of nodes to their subgraph output replacement node, used
    to remap outputs when they are created in output_graph.
    """
    duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
        output_graph.graph
    )

    # Used to track which nodes were replaced with subgraph outputs. Today,
    # we have to register the new subgraph submodules before the graph
    # outputs have been created, so we pass the replacement mapping back to
    # the output graph to do the replacements at the site of output creation.
    output_replacements: dict[Node, Node] = {}
    for region_group in duplicated_region_groups:
        inds_with_external_users = _get_all_output_indices(region_group)
        region = region_group[0]
        (
            subgraph,
            node_ind_arg_inds,
        ) = _create_subgraph(region, inds_with_external_users)
        sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
        subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
        with output_graph.graph.inserting_before():
            get_subgraph_node = output_graph.graph.create_node(
                "get_attr", subgraph_name, (), {}
            )
        for region in region_group:
            _replace_region_with_subgraph(
                output_graph.graph,
                region,
                get_subgraph_node,
                node_ind_arg_inds.keys(),
                inds_with_external_users,
                sub_gm,
                subgraph_name,
                output_replacements,
            )

    return output_replacements
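
# A hypothetical usage sketch (the surrounding OutputGraph wiring is assumed,
# not shown in this module): the returned mapping lets the caller swap an
# original region node for its getitem replacement when it later emits graph
# outputs, e.g.:
#
#   replacements = apply_graph_deduplication(output_graph)
#   output_args = [replacements.get(n, n) for n in pending_outputs]
#
# where pending_outputs is a placeholder name for the caller's not-yet-created
# graph outputs.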


# Flattens args and kwargs, with support for slices.
# Note: a better way to do this would be to register/unregister slices
# as pytree nodes, but there is no unregister API in the PyTorch
# pytree impl.
def _flatten_args_kwargs(args: Any) -> list[Node]:
    fully_flattened = []

    def flatten(args: Any) -> None:
        flattened, _ = tree_flatten(args)
        for arg in flattened:
            if isinstance(arg, slice):
                # Slices are opaque to tree_flatten, so unpack
                # start/stop/step manually; each may itself be a Node.
                flatten((arg.start, arg.stop, arg.step))
            else:
                fully_flattened.append(arg)

    flatten(args)

    return fully_flattened
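
# Illustrative example (x and y stand for fx Nodes): a slice in the flattened
# args is recursively unpacked into its start/stop/step components, which may
# themselves be Nodes:
#
#   _flatten_args_kwargs(((x, slice(y, None, 2)), {}))
#   # -> [x, y, None, 2]   (torch's pytree treats None as a leaf)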


def _replace_region_with_subgraph(
    graph: torch.fx.Graph,
    region: Region,
    get_subgraph_node: Node,
    node_ind_arg_ind: Iterable[tuple[int, int]],
    inds_with_external_users: list[int],
    sub_gm: torch.fx.GraphModule,
    subgraph_name: str,
    output_replacements: dict[Node, Node],
) -> None:
    # Gather the external inputs in the order recorded during subgraph
    # creation, so the call matches the subgraph's placeholder order.
    sub_args = []
    for node_ind, arg_ind in node_ind_arg_ind:
        node = region[node_ind]
        flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
        sub_args.append(flattened_args_kwargs[arg_ind])

    invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
    fake_inputs = [node.meta["example_value"] for node in sub_args]

    if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
        log.debug(
            "NYI: Failed to substitute region %s due to input alias or mutation",
            region,
        )
        return

    latest_region_node = region[-1]
    with graph.inserting_after(latest_region_node):
        invoke_subgraph_node = graph.create_node(
            "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
        )
        with graph.inserting_after(invoke_subgraph_node):
            for ind, external_user_ind in enumerate(inds_with_external_users):
                node = region[external_user_ind]
                subgraph_output = graph.create_node(
                    "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
                )
                output_replacements[node] = subgraph_output
                node.replace_all_uses_with(subgraph_output, propagate_meta=True)

    # Erase in reverse topological order
    for node in reversed(region):
        graph.erase_node(node)


def _get_external_inputs(
    region: Region,
) -> dict[Node, tuple[int, int]]:
    external_node_to_indices: dict[Node, tuple[int, int]] = {}
    region_unique = set(region)
    for node_ind, node in enumerate(region):
        flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
        for arg_ind, in_node in enumerate(flattened_args_kwargs):
            if (
                isinstance(in_node, Node)
                and in_node not in region_unique
                and in_node not in external_node_to_indices
            ):
                # Record only the first (node index, arg index) at which this
                # external input is consumed within the region.
                external_node_to_indices[in_node] = (node_ind, arg_ind)

    return external_node_to_indices
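
# Illustrative example (names are made up): for a region [add, mul] where add
# reads external node x as its first flattened arg and mul reads external
# node w as its second, the result is:
#
#   {x: (0, 0), w: (1, 1)}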


def _get_all_output_indices(regions: list[Region]) -> list[int]:
    # Scan all regions to get the set of all possible output node indices
    # in the region. Perhaps we can record this information during region
    # creation for more efficiency?
    inds_with_external_users: set[int] = set()
    for region in regions:
        _get_inds_with_external_users(region, inds_with_external_users)

    return sorted(inds_with_external_users)


def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None:
    for ind, node in enumerate(region):
        for user in node.users:
            if user not in region:
                inds_unique.add(ind)
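
# Illustrative example (names are made up): for a region [add, mul, sub]
# where only mul and sub have users outside the region, inds_unique
# accumulates {1, 2}.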


def _copy_nodes_and_remap_inputs(
    subgraph: torch.fx.Graph, region: Region
) -> dict[tuple[int, int], Any]:
    external_inputs_to_indices = _get_external_inputs(region)
    indices_to_placeholder_ind: dict[tuple[int, int], Any] = {}
    region_to_subgraph_node = {}
    for node in external_inputs_to_indices.keys():
        placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
        region_to_subgraph_node[node] = placeholder
        arg_indices = external_inputs_to_indices[node]
        # Note: insertion order matches the order in which placeholders were
        # created; this defines the calling convention of the subgraph.
        indices_to_placeholder_ind[arg_indices] = None

    def map_arg(node: Node) -> Node:
        if node in region_to_subgraph_node:
            return region_to_subgraph_node[node]
        else:
            return node

    for node in region:
        subgraph_node = subgraph.node_copy(node, map_arg)
        region_to_subgraph_node[node] = subgraph_node

    return indices_to_placeholder_ind
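
# Illustrative example (continuing the one after _get_external_inputs): for
# external inputs x at (0, 0) and w at (1, 1), the subgraph gains placeholders
# subgraph_input_x and subgraph_input_w, in that order, and the returned
# dict's key order encodes the same calling convention:
#
#   {(0, 0): None, (1, 1): None}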


def _create_subgraph_outputs(
    subgraph: torch.fx.Graph, inds_to_output: list[int]
) -> None:
    # The indices in inds_to_output index into the subgraph's non-placeholder,
    # non-output nodes, which were copied in the same order as the nodes of
    # the original region.
    node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
    out_tup = tuple(node_list[ind] for ind in inds_to_output)
    subgraph.output(out_tup)


def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]:
    subgraph: torch.fx.Graph = torch.fx.Graph()
    node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
    _create_subgraph_outputs(subgraph, inds_with_external_users)
    return subgraph, node_ind_input_inds