# mypy: allow-untyped-defs
import inspect
import warnings
from typing import Any, Optional

from typing_extensions import deprecated

import torch
from torch.utils.data.datapipes.iter.sharding import (
    _ShardingIterDataPipe,
    SHARDING_PRIORITIES,
)
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps


__all__ = [
    "apply_random_seed",
    "apply_sharding",
    "apply_shuffle_seed",
    "apply_shuffle_settings",
    "get_all_graph_pipes",
]


def get_all_graph_pipes(graph: DataPipeGraph) -> list[DataPipe]:
    return _get_all_graph_pipes_helper(graph, set())


def _get_all_graph_pipes_helper(
    graph: DataPipeGraph, id_cache: set[int]
) -> list[DataPipe]:
    results: list[DataPipe] = []
    for dp_id, (datapipe, sub_graph) in graph.items():
        if dp_id in id_cache:
            continue
        id_cache.add(dp_id)
        results.append(datapipe)
        results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
    return results
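

# A minimal usage sketch of the traversal helpers above (illustrative only;
# ``IterableWrapper`` comes from ``torch.utils.data.datapipes.iter`` and is
# not otherwise used in this module):
#
#     >>> from torch.utils.data.datapipes.iter import IterableWrapper
#     >>> dp = IterableWrapper(range(10)).shuffle()
#     >>> graph = traverse_dps(dp)
#     >>> [type(p).__name__ for p in get_all_graph_pipes(graph)]
#     ['ShufflerIterDataPipe', 'IterableWrapperIterDataPipe']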


def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
    return isinstance(datapipe, _ShardingIterDataPipe) or (
        hasattr(datapipe, "apply_sharding")
        and inspect.ismethod(datapipe.apply_sharding)
    )


def apply_sharding(
    datapipe: DataPipe,
    num_of_instances: int,
    instance_id: int,
    sharding_group=SHARDING_PRIORITIES.DEFAULT,
) -> DataPipe:
    r"""
    Apply dynamic sharding over every ``sharding_filter`` DataPipe in the graph,
    i.e. any DataPipe that exposes an ``apply_sharding`` method.

    A ``RuntimeError`` is raised when multiple ``sharding_filter`` DataPipes are
    present in the same branch.
    """
    graph = traverse_dps(datapipe)

    def _helper(graph, prev_applied=None):
        for dp, sub_graph in graph.values():
            applied = None
            if _is_sharding_datapipe(dp):
                if prev_applied is not None:
                    raise RuntimeError(
                        "Sharding twice on a single pipeline is likely unintended and will cause data loss. "
                        f"Sharding already applied to {prev_applied} while trying to apply to {dp}"
                    )
                # For backward compatibility, only pass `sharding_group` if the
                # method accepts it
                sig = inspect.signature(dp.apply_sharding)
                if len(sig.parameters) < 3:
                    dp.apply_sharding(num_of_instances, instance_id)
                else:
                    dp.apply_sharding(
                        num_of_instances, instance_id, sharding_group=sharding_group
                    )
                applied = dp
            if applied is None:
                applied = prev_applied
            _helper(sub_graph, applied)

    _helper(graph)

    return datapipe
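

# A minimal usage sketch (illustrative only; ``IterableWrapper`` and
# ``sharding_filter`` are assumed from the DataPipes API):
#
#     >>> from torch.utils.data.datapipes.iter import IterableWrapper
#     >>> dp = IterableWrapper(range(8)).sharding_filter()
#     >>> dp = apply_sharding(dp, num_of_instances=2, instance_id=0)
#     >>> list(dp)  # instance 0 keeps every 2nd element
#     [0, 2, 4, 6]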


def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
    return (
        hasattr(datapipe, "set_shuffle")
        and hasattr(datapipe, "set_seed")
        and inspect.ismethod(datapipe.set_shuffle)
        and inspect.ismethod(datapipe.set_seed)
    )


def apply_shuffle_settings(
    datapipe: DataPipe, shuffle: Optional[bool] = None
) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find and set the shuffle attribute.

    The setting is applied to every ``DataPipe`` that exposes both the
    ``set_shuffle`` and ``set_seed`` APIs.

    Args:
        datapipe: DataPipe whose graph will have its shuffle attribute set
        shuffle: Shuffle option (default: ``None``, which leaves the graph unchanged)
    """
    if shuffle is None:
        return datapipe

    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
    if not shufflers and shuffle:
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        datapipe = datapipe.shuffle()
        shufflers = [datapipe]

    for shuffler in shufflers:
        shuffler.set_shuffle(shuffle)

    return datapipe
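

# A minimal usage sketch (illustrative only): disabling an existing
# ``Shuffler`` in place rather than rebuilding the pipeline.
#
#     >>> from torch.utils.data.datapipes.iter import IterableWrapper
#     >>> dp = IterableWrapper(range(10)).shuffle()
#     >>> dp = apply_shuffle_settings(dp, shuffle=False)
#     >>> list(dp)  # the Shuffler now passes elements through in order
#     [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]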


@deprecated(
    "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in future releases. "
    "Please use `apply_random_seed` instead.",
    category=FutureWarning,
)
def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
    return apply_random_seed(datapipe, rng)


def _is_random_datapipe(datapipe: DataPipe) -> bool:
    return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)


def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find every random ``DataPipe`` with an API of ``set_seed``.

    Then set a random seed drawn from the provided RNG on each of those ``DataPipes``.

    Args:
        datapipe: DataPipe that needs to set randomness
        rng: Random number generator to generate random seeds
    """
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    # Track ids of DataPipes to avoid seeding the same DataPipe more than once.
    # `id` is used so that unhashable DataPipes are handled as well.
    cache = set()
    random_datapipes = []
    for pipe in all_pipes:
        if id(pipe) in cache:
            continue
        if _is_random_datapipe(pipe):
            random_datapipes.append(pipe)
        cache.add(id(pipe))

    for pipe in random_datapipes:
        random_seed = int(
            torch.empty((), dtype=torch.int64).random_(generator=rng).item()
        )
        pipe.set_seed(random_seed)

    return datapipe
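

# A minimal usage sketch (illustrative only): seeding every random DataPipe
# in the graph from a single generator, e.g. for reproducible shuffling.
#
#     >>> from torch.utils.data.datapipes.iter import IterableWrapper
#     >>> rng = torch.Generator().manual_seed(42)
#     >>> dp = IterableWrapper(range(10)).shuffle()
#     >>> dp = apply_random_seed(dp, rng)  # each Shuffler gets a seed drawn from rng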