# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import Union

import torch
from torch import fx


logger = logging.getLogger(__name__)


def flatten_args_detach(args):
    """
    Flatten the args into a list form and detach the tensors from the
    computational graph.
    """
    flat_detached_args = []

    def extract_tensor_args(a):
        nonlocal flat_detached_args
        if isinstance(a, torch.Tensor):
            val = a.detach().requires_grad_(a.requires_grad)
            flat_detached_args.append(val)
            return val
        else:
            flat_detached_args.append(a)
            return a

    new_args = fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return new_args, flat_detached_args


def flatten_args(args):
    """
    Flatten the args into a list form.
    """
    flat_args = []

    def extract_tensor_args(a):
        nonlocal flat_args
        flat_args.append(a)
        return a

    fx.node.map_aggregate(
        args,
        extract_tensor_args,
    )

    return flat_args


class PipeliningShapeError(RuntimeError):
    """Shape mismatch between configured and runtime values."""


def validate_tensor_metadata(desc, expected, given):
    """Check that `given` matches `expected` in shape, dtype, and stride."""
    if expected.shape != given.shape:
        raise PipeliningShapeError(
            f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
        )
    if expected.dtype != given.dtype:
        raise PipeliningShapeError(
            f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
        )
    if expected.stride() != given.stride():
        raise PipeliningShapeError(
            f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
        )


def validate_tensors_metadata(
    desc,
    expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
    actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
    """Check that the actual tensors match the expected ones in count and metadata."""
    if len(expected_tensors) != len(actual_tensors):
        raise PipeliningShapeError(
            f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
        )
    for i in range(len(expected_tensors)):
        validate_tensor_metadata(
            f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
        )


def generate_stage_to_rank_mapping(
    pp_size: int, num_stages: int, style: str = "loop"
) -> dict[int, int]:
    """
    Compute the stage id to rank mapping for either a looped or V-style schedule.

    Most commonly, num_stages == pp_size * 2, but this function can be used to
    compute the mapping for any number of stages per rank.
    """
    mapping = {}
    if style == "loop":
        for stage_index in range(num_stages):
            mapping[stage_index] = stage_index % pp_size
    elif style == "v":
        if num_stages % pp_size != 0:
            raise ValueError(
                f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
            )
        rank_index = 0
        for stage_index in range(num_stages):
            mapping[stage_index] = rank_index
            # Don't change rank if we are on the border (to keep the V shape)
            if (stage_index + 1) % pp_size == 0:
                continue
            if (stage_index // pp_size) % 2 == 0:
                rank_index += 1
            else:
                rank_index -= 1
    else:
        raise ValueError(f"Style {style} is not supported.")
    return mapping


@dataclass
class PipeInfo:
    """
    Captures information for a pipeline (`Pipe` object).
    """

    graph: fx.Graph
    num_stages: int
    has_loss_and_backward: bool
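

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the upstream module): the
# demo below assumes a CPU-only environment and shows (1) how
# `flatten_args_detach` detaches tensors while preserving requires_grad and the
# original nesting, (2) how metadata validation surfaces mismatches, and
# (3) the stage-to-rank mappings produced for a "loop" versus a "v" schedule
# with 4 pipeline ranks and 8 stages.
if __name__ == "__main__":
    # 1. Flatten a nested arg structure; tensors come back detached but keep
    #    requires_grad, and non-tensor leaves pass through unchanged.
    t = torch.ones(2, requires_grad=True)
    new_args, flat = flatten_args_detach(((t, 3), {"x": t}))
    assert flat[0].requires_grad and flat[0] is not t
    assert flat[1] == 3

    # 2. Metadata validation raises PipeliningShapeError on any mismatch.
    try:
        validate_tensors_metadata("demo", [torch.ones(2, 3)], [torch.ones(3, 2)])
    except PipeliningShapeError as e:
        print(e)  # demo: value 0 has a shape mismatch: ...

    # 3. Stage-to-rank mappings for pp_size=4, num_stages=8. The loop style
    #    wraps stages around the ranks; the v style folds back at the borders.
    print(generate_stage_to_rank_mapping(4, 8, style="loop"))
    # -> {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}
    print(generate_stage_to_rank_mapping(4, 8, style="v"))
    # -> {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}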