team-10/env/Lib/site-packages/torch/distributed/pipelining/_utils.py

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Union

import torch
from torch import fx

logger = logging.getLogger(__name__)


def flatten_args_detach(args):
"""
    Flatten the args into a list form and detach the tensors from the computational graph.
"""
    flat_detached_args = []

    def extract_tensor_args(a):
nonlocal flat_detached_args
if isinstance(a, torch.Tensor):
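            # Detach cuts the value out of the autograd graph; re-applying
            # requires_grad lets the receiving stage build its own local graph.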
val = a.detach().requires_grad_(a.requires_grad)
flat_detached_args.append(val)
return val
else:
flat_detached_args.append(a)
return a
new_args = fx.node.map_aggregate(
args,
extract_tensor_args,
)
return new_args, flat_detached_args
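

# Illustrative usage sketch (not part of the original module): the detached
# tensors share storage with the originals but are cut from the autograd
# graph, while requires_grad is preserved.
#
#   >>> t = torch.ones(2, requires_grad=True)
#   >>> new_args, flat = flatten_args_detach((t, 1))
#   >>> flat[0] is t, flat[0].requires_grad
#   (False, True)

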
def flatten_args(args):
"""
Flatten the args into a list form.
"""
    flat_args = []

    def extract_tensor_args(a):
nonlocal flat_args
flat_args.append(a)
return a
fx.node.map_aggregate(
args,
extract_tensor_args,
)
return flat_args
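

# Illustrative sketch (not part of the original module): nesting is flattened
# in traversal order, and non-tensor leaves are included as-is.
#
#   >>> flatten_args((1, [2, {"k": 3}]))
#   [1, 2, 3]

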
class PipeliningShapeError(RuntimeError):
"""Shape mismatch between configured and runtime values."""


def validate_tensor_metadata(desc, expected, given):
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )
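

# Example of the failure mode (illustrative, not part of the original module);
# the "stage0 input0" description is a hypothetical caller-supplied label.
#
#   >>> expected = torch.empty(2, 3)
#   >>> validate_tensor_metadata("stage0 input0", expected, torch.empty(3, 2))
#   ... raises PipeliningShapeError: stage0 input0 has a shape mismatch:
#   ... expected torch.Size([2, 3]) actual torch.Size([3, 2])

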
def validate_tensors_metadata(
desc,
expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(
f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
)
    for i, (expected, actual) in enumerate(zip(expected_tensors, actual_tensors)):
        validate_tensor_metadata(f"{desc}: value {i}", expected, actual)
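

# Illustrative sketch (not part of the original module): a length mismatch is
# reported before any per-tensor comparison is attempted.
#
#   >>> validate_tensors_metadata("stage0 inputs", [torch.empty(2)], [])
#   ... raises PipeliningShapeError: stage0 inputs: Number of values (0) does
#   ... not match expected number (1)

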
def generate_stage_to_rank_mapping(
pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, int]:
"""
Compute the stage id to rank mapping for either a looped or V-style schedule.
Most commonly num_stages == pp_size * 2, but this function can be used to
compute the mapping for any number of stages per rank.
"""
mapping = {}
if style == "loop":
for stage_index in range(num_stages):
mapping[stage_index] = stage_index % pp_size
elif style == "v":
if num_stages % pp_size != 0:
raise ValueError(
f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
)
rank_index = 0
for stage_index in range(num_stages):
mapping[stage_index] = rank_index
            # don't change rank if we are on the border (to keep the V shape)
            if (stage_index + 1) % pp_size == 0:
                continue
            # even passes walk up the ranks, odd passes walk back down
            if (stage_index // pp_size) % 2 == 0:
                rank_index += 1
            else:
                rank_index -= 1
else:
raise ValueError(f"Style {style} is not supported.")
return mapping
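

# Illustrative sketch (not part of the original module), for pp_size=2 and
# num_stages=4: "loop" wraps around the ranks, while "v" bounces back at the
# border so adjacent stages stay on the same rank.
#
#   >>> generate_stage_to_rank_mapping(2, 4, style="loop")
#   {0: 0, 1: 1, 2: 0, 3: 1}
#   >>> generate_stage_to_rank_mapping(2, 4, style="v")
#   {0: 0, 1: 1, 2: 1, 3: 0}

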
@dataclass
class PipeInfo:
"""
Captures information for a pipeline (`Pipe` object).
"""
graph: fx.Graph
num_stages: int
has_loss_and_backward: bool