import re
|
|
from typing import Any, Callable, Union
|
|
|
|
import torch
|
|
from torch.utils._pytree import tree_flatten_with_path, tree_map
|
|
|
|
|
|
# A pytree key path: a tuple of pytree key entries identifying one leaf.
KeyPath = tuple[Any, ...]
# Maps a scalar (int/float) leaf to its pseudo-"shape" tuple.
# NOTE(review): NonTensorShapeFn is not used anywhere in this file —
# confirm whether external code imports it before removing.
NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]]

# Public API of this module.
__all__ = [
    "normalize_source_name",
    "module_to_nested_dict",
    "track_dynamism_across_examples",
    "clone_and_convert_to_meta",
]
|
|
|
|
|
|
def normalize_source_name(name: str) -> str:
    """Rewrite dotted attribute access as subscript access.

    Every ``.ident`` segment becomes ``['ident']``; the leading name is
    left untouched, e.g. ``"L.foo.bar"`` -> ``"L['foo']['bar']"``.
    """
    attr_access = re.compile(r"\.([a-zA-Z_][a-zA-Z0-9_]*)")
    return attr_access.sub(lambda m: f"['{m.group(1)}']", name)
|
|
|
|
|
|
def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
    """Recursively convert an nn.Module into a nested dictionary.

    The result mirrors nn.Module internals: public scalar/tensor attributes
    are stored at the top level, parameters and buffers under
    ``"_parameters"``, and child modules (recursively converted) under
    ``"_modules"``.

    Args:
        module: The module to convert.

    Returns:
        A nested dict with ``"_parameters"``, ``"_modules"``, and any public
        non-callable int/float/tensor attributes of ``module``.
    """
    self_dict: dict[str, Any] = {"_parameters": {}, "_modules": {}}

    for attr_name in dir(module):
        if attr_name.startswith("_"):
            continue
        # Fetch the attribute once. The original code called getattr twice
        # (inside callable(...) and again below), which ran property getters
        # twice per attribute.
        attr_value = getattr(module, attr_name)
        if callable(attr_value):
            continue
        if (
            not isinstance(attr_value, torch.nn.Module)
            # bool is excluded explicitly because it is a subclass of int.
            and isinstance(attr_value, (int, float, torch.Tensor))
            and type(attr_value) is not bool
        ):
            self_dict[attr_name] = attr_value

    for name, param in module.named_parameters(recurse=False):
        self_dict["_parameters"][name] = param
    # NOTE(review): buffers are merged into "_parameters" rather than a
    # separate "_buffers" bucket — presumably intentional so both appear as
    # tensor leaves under one key; confirm with callers.
    for name, buffer in module.named_buffers(recurse=False):
        self_dict["_parameters"][name] = buffer

    for name, submodule in module.named_children():
        self_dict["_modules"][name] = module_to_nested_dict(submodule)

    return self_dict
|
|
|
|
|
|
def track_dynamism_across_examples(
|
|
example_inputs: list[Any],
|
|
) -> dict[Any, Any]:
|
|
"""
|
|
This function analyzes a list of example inputs to determine the dynamism of their shapes.
|
|
It tracks whether the dimensions of tensors or non-tensor values change across
|
|
different examples. The function returns a dictionary where each key represents
|
|
a path to a value in the input examples, and the corresponding value is a tuple
|
|
indicating which dimensions are dynamic (i.e., change across examples). This
|
|
helps in understanding how the structure of data varies across different instances.
|
|
"""
|
|
tracking: dict[KeyPath, tuple[list[set[Any]], bool]] = {}
|
|
|
|
for ex in example_inputs:
|
|
if "self" in ex and isinstance(ex["self"], torch.nn.Module):
|
|
ex["self"] = module_to_nested_dict(ex["self"])
|
|
leaves_with_paths, _ = tree_flatten_with_path(ex)
|
|
for key_path, value in leaves_with_paths:
|
|
if not isinstance(value, (int, float, torch.Tensor)):
|
|
continue
|
|
if isinstance(value, torch.Tensor):
|
|
shape: tuple[int | float, ...] = tuple(value.shape)
|
|
is_tensor = True
|
|
else:
|
|
shape = (value,)
|
|
is_tensor = False
|
|
if key_path not in tracking:
|
|
tracking[key_path] = ([set() for _ in range(len(shape))], is_tensor)
|
|
else:
|
|
dim_sets, flag = tracking[key_path]
|
|
if flag != is_tensor:
|
|
pass
|
|
while len(dim_sets) < len(shape):
|
|
dim_sets.append(set())
|
|
for i, dim in enumerate(shape):
|
|
tracking[key_path][0][i].add(dim)
|
|
|
|
output: dict[Any, Any] = {}
|
|
for key_path, (dim_sets, _is_tensor) in tracking.items():
|
|
final_dyn = tuple(len(s) > 1 for s in dim_sets)
|
|
key_str = "L" + "".join(f"{str(k)}" for k in key_path)
|
|
key = key_path[0].key # type: ignore[attr-defined]
|
|
if key not in output:
|
|
output[key] = {}
|
|
output[key][key_str] = final_dyn
|
|
return output
|
|
|
|
|
|
def clone_and_convert_to_meta(example_input: Any) -> Any:
    """
    Return a copy of ``example_input`` in which every tensor leaf has been
    cloned and moved to device="meta", while non-tensor leaves are kept by
    reference. Nested containers are traversed recursively via pytree.
    """

    def _leaf_to_meta(leaf: Any) -> Any:
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf.clone().to(device="meta")

    return tree_map(_leaf_to_meta, example_input)
|