team-10/env/Lib/site-packages/torch/distributed/pipelining/_unflatten.py
2025-08-02 07:34:44 +02:00

30 lines
952 B
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
from collections import defaultdict
import torch
from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule:
# Create an empty GraphModule to hold the outlined modules
new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
seen_nodes: dict[str, torch.fx.Node] = {}
seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: dict[str, set[str]] = defaultdict(set)
created_modules: dict[str, torch.nn.Module] = {}
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),
seen_nodes,
seen_modules,
seen_attrs,
created_modules,
None,
[("", None, 0)],
"",
{},
module=new_module,
).run_outer()
new_module.graph.lint()
new_module.recompile()
return new_module