# Copyright (c) Meta Platforms, Inc. and affiliates from collections import defaultdict import torch from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule: # Create an empty GraphModule to hold the outlined modules new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) seen_nodes: dict[str, torch.fx.Node] = {} seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) seen_attrs: dict[str, set[str]] = defaultdict(set) created_modules: dict[str, torch.nn.Module] = {} _ModuleFrame( orig_graph, tuple(orig_graph.nodes), seen_nodes, seen_modules, seen_attrs, created_modules, None, [("", None, 0)], "", {}, module=new_module, ).run_outer() new_module.graph.lint() new_module.recompile() return new_module