team-10/venv/Lib/site-packages/torch/distributed/checkpoint/planner_helpers.py
2025-08-02 02:00:33 +02:00

478 lines
16 KiB
Python

# mypy: allow-untyped-defs
import io
from typing import Any, Callable, cast
import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from .metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
MetadataIndex,
STATE_DICT_TYPE,
STORAGE_TYPES,
TensorProperties,
TensorStorageMetadata,
)
from .planner import (
LoadItemType,
ReadItem,
SavePlan,
TensorWriteData,
WriteItem,
WriteItemType,
)
from .resharding import (
_check_shard_metadata_pair_overlap,
_shards_get_overlap_region_wrt_saved_tensor,
)
# Only the chunk-list resharding helper is part of this module's public API;
# every other helper defined here is private (underscore-prefixed).
__all__: list[str] = ["create_read_items_for_chunk_list"]
def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool:
"""
Compare the two Save plans and return True if they are equal.
Args:
plan (SavePlan): First SavePlan to compare.
other_plan (SavePlan): Second SavePlan to compare.
Returns:
True if the two plans are equal, False otherwise.
"""
if plan.usable != other_plan.usable:
return False
# Both the plans should have the same number of items
if len(plan.items) != len(other_plan.items):
return False
# Both the plans should have the same write items.
for plan_item, other_plan_item in zip(plan.items, other_plan.items):
# Write item type should be same
if plan_item.type != other_plan_item.type:
return False
plan_metadata_index = plan_item.index
other_plan_metadata_index = other_plan_item.index
# Write item metadata_index should be same
if (
plan_metadata_index.fqn != other_plan_metadata_index.fqn
or plan_metadata_index.offset != other_plan_metadata_index.offset
or plan_metadata_index.index != other_plan_metadata_index.index
):
return False
# Write item tensor_data should be present in both the write items plans, if it exists in either of them.
tensor_data = plan_item.tensor_data
other_tensor_data = other_plan_item.tensor_data
if (tensor_data and not other_tensor_data) or (
not tensor_data and other_tensor_data
):
return False
if tensor_data and other_tensor_data:
# Write item tensor_data size should be same
if tensor_data.size != other_tensor_data.size:
return False
# Write item tensor_data chunk should be present in both the write items, if it exists in either of them.
chunk = tensor_data.chunk
other_chunk = other_tensor_data.chunk
if (chunk and not other_chunk) or (not chunk and other_chunk):
return False
# Write item tensor_data chunk offsets and sizes should be same
if chunk and other_chunk:
if (
chunk.offsets != other_chunk.offsets
or chunk.sizes != other_chunk.sizes
):
return False
return True
def _merge_delta_local_plans(
    cached_plans: list[SavePlan],
    delta_plans: list[SavePlan],
) -> list[SavePlan]:
    """
    Combine cached plans with delta plans, position by position.

    The two lists are walked in lockstep with ``zip`` (so the result is as
    long as the shorter input). At each position:

    * a delta plan that is present and marked ``usable=False`` is replaced by
      the cached plan at the same position;
    * any other delta entry (a usable plan, or a falsy entry such as ``None``)
      is kept as-is.

    Args:
        cached_plans (List[SavePlan]): Previously computed plans.
        delta_plans (List[SavePlan]): Newly computed plans; may contain empty entries.

    Returns:
        A list with one plan per zipped position.
    """
    return [
        cached if (delta and not delta.usable) else delta
        for cached, delta in zip(cached_plans, delta_plans)
    ]
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    """Describe ``tensor`` as a single chunk starting at the origin and spanning its full shape."""
    full_shape = tensor.size()
    origin = torch.Size([0] * len(full_shape))
    return ChunkStorageMetadata(offsets=origin, sizes=full_shape)
def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    """Convert a ShardedTensor shard's offsets and sizes into a ``ChunkStorageMetadata``."""
    start = torch.Size(shard_md.shard_offsets)
    extent = torch.Size(shard_md.shard_sizes)
    return ChunkStorageMetadata(offsets=start, sizes=extent)
def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    """
    Build the ``TensorWriteData`` for one shard of a ``ShardedTensor``.

    The chunk is taken from ``shard_md``. The global size and the tensor
    properties are taken from the sharded tensor's own metadata. The
    properties are copied field by field into the checkpoint
    ``TensorProperties`` type.
    """
    st_meta = sharded_tensor.metadata()
    src_props = st_meta.tensor_properties
    copied_props = TensorProperties(
        dtype=src_props.dtype,
        layout=src_props.layout,
        requires_grad=src_props.requires_grad,
        memory_format=src_props.memory_format,
        pin_memory=src_props.pin_memory,
    )
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=copied_props,
        size=st_meta.size,
    )
def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    """
    Build the single SHARD ``WriteItem`` for the locally held piece of a DTensor.

    The local shape and global offset come from
    ``compute_local_shape_and_global_offset`` applied to the DTensor's global
    shape, device mesh and placements. Tensor properties are taken from the
    local tensor, while ``size`` records the DTensor's global size.
    """
    local_shape, global_offset = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    chunk_sizes = torch.Size(local_shape)
    chunk_offsets = torch.Size(global_offset)
    write_data = TensorWriteData(
        chunk=ChunkStorageMetadata(offsets=chunk_offsets, sizes=chunk_sizes),
        properties=TensorProperties.create_from_tensor(tensor.to_local()),
        size=tensor.size(),
    )
    return WriteItem(
        index=MetadataIndex(fqn, chunk_offsets),
        type=WriteItemType.SHARD,
        tensor_data=write_data,
    )
def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    """Build a SHARD ``WriteItem`` for one shard of ``sharded_tensor``, indexed by the shard's offsets."""
    shard_index = MetadataIndex(fqn, torch.Size(shard_md.shard_offsets))
    write_data = _sharded_tensor_metadata(sharded_tensor, shard_md)
    return WriteItem(
        index=shard_index,
        type=WriteItemType.SHARD,
        tensor_data=write_data,
    )
def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    """
    Build a TENSOR ``WriteItem`` that writes ``tensor`` as one whole chunk.

    The chunk starts at the origin and covers the full tensor, so both the
    chunk sizes and the recorded global size equal ``tensor.size()``.
    """
    shape = tensor.size()
    origin = torch.Size([0] * len(shape))
    write_data = TensorWriteData(
        chunk=ChunkStorageMetadata(offsets=origin, sizes=shape),
        properties=TensorProperties.create_from_tensor(tensor),
        size=shape,
    )
    return WriteItem(
        index=MetadataIndex(fqn, origin),
        type=WriteItemType.TENSOR,
        tensor_data=write_data,
    )
def _create_write_item_for_bytesio(fqn: str, bytes: Any):
    """
    Build a BYTE_IO ``WriteItem`` for a non-tensor value stored under ``fqn``.

    Only the key is recorded. The value itself (``bytes``) is not inspected
    here, and the item carries no tensor data.
    """
    byte_index = MetadataIndex(fqn)
    return WriteItem(index=byte_index, type=WriteItemType.BYTE_IO)
def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    """
    Build a BYTE_IO ``ReadItem``.

    The scalar ``dest_offset``, ``storage_offset`` and ``length`` are each
    wrapped in a one-element ``torch.Size``.
    """

    def _one(value):
        return torch.Size((value,))

    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=_one(dest_offset),
        storage_index=storage_index,
        storage_offsets=_one(storage_offset),
        lengths=_one(length),
    )
def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    """
    Build a TENSOR ``ReadItem``.

    ``dest_offsets``, ``storage_offsets`` and ``lengths`` are per-dimension
    sequences and are each converted to a ``torch.Size``.
    """
    dst = torch.Size(dest_offsets)
    src = torch.Size(storage_offsets)
    extent = torch.Size(lengths)
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=dst,
        storage_index=storage_index,
        storage_offsets=src,
        lengths=extent,
    )
def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: list[ChunkStorageMetadata],
) -> list[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Every (local chunk, saved chunk) pair whose regions overlap yields one
    ``ReadItem`` that copies the overlapping region. All pairs are checked,
    so the number of overlap tests is
    ``len(local_chunks) * len(checkpoint_md.chunks)``.

    Args:
        fqn (str) : The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    reads: list[ReadItem] = []
    # Naive all-pairs scan; this can be optimized later.
    for local_idx, local_chunk in enumerate(local_chunks):
        for saved_idx, saved_chunk in enumerate(checkpoint_md.chunks):
            if _check_shard_metadata_pair_overlap(local_chunk, saved_chunk):
                src_starts: list = []
                dst_starts: list = []
                extents: list = []
                # Each entry is (dim, offset in saved chunk, offset in local chunk, length).
                for _dim, saved_off, local_off, extent in (
                    _shards_get_overlap_region_wrt_saved_tensor(
                        saved_shard=saved_chunk, current_shard=local_chunk
                    )
                ):
                    src_starts.append(saved_off)
                    dst_starts.append(local_off)
                    extents.append(extent)
                reads.append(
                    _create_read_item_for_tensor(
                        dest_index=MetadataIndex(fqn, local_chunk.offsets, local_idx),
                        dest_offsets=dst_starts,
                        storage_index=MetadataIndex(
                            fqn, saved_chunk.offsets, saved_idx
                        ),
                        storage_offsets=src_starts,
                        lengths=extents,
                    )
                )
    return reads
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    """
    Build a ``SavePlan`` with write items describing every entry of ``state_dict``.

    * DTensor: one SHARD item.
    * ShardedTensor: one SHARD item per shard listed in
      ``obj.metadata().shards_metadata`` (not only the local shards, unlike
      ``_create_write_items``).
    * Plain tensor: one TENSOR item.
    * Anything else: one BYTE_IO item.
    """
    items: list[WriteItem] = []
    for fqn, value in state_dict.items():
        if isinstance(value, DTensor):
            items.append(_create_write_items_for_dtensor(fqn, value))
            continue
        if isinstance(value, ShardedTensor):
            for shard_md in value.metadata().shards_metadata:
                items.append(_create_write_item_for_shard(fqn, value, shard_md))
            continue
        if isinstance(value, torch.Tensor):
            items.append(_create_write_item_for_tensor(fqn, value))
        else:
            items.append(_create_write_item_for_bytesio(fqn, value))
    return SavePlan(items)
def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
    """
    Produce the write items for a single state_dict value.

    * Objects exposing ``__create_write_items__`` (e.g. DTensor) build their own items.
    * ShardedTensor: one SHARD item per local shard.
    * Plain tensor: one TENSOR item.
    * Anything else: one BYTE_IO item.
    """
    if hasattr(object, "__create_write_items__"):
        # DTensor implements _Checkpointable
        return object.__create_write_items__(fqn, object)
    if isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, local_shard.metadata)
            for local_shard in object.local_shards()
        ]
    single_item = (
        _create_write_item_for_tensor(fqn, object)
        if isinstance(object, torch.Tensor)
        else _create_write_item_for_bytesio(fqn, object)
    )
    return [single_item]
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    """
    Return the chunk for the locally held piece of ``tensor``.

    The chunk has global offsets and local sizes, as computed by
    ``compute_local_shape_and_global_offset`` from the DTensor's global
    shape, device mesh and placements.
    """
    local_shape, global_offset = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    chunk_sizes = torch.Size(local_shape)
    chunk_offsets = torch.Size(global_offset)
    return ChunkStorageMetadata(offsets=chunk_offsets, sizes=chunk_sizes)
def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
    """
    Return the local chunks of a tensor-like object.

    * Objects exposing ``__create_chunk_list__`` (e.g. DTensor) report their own chunks.
    * ShardedTensor: one chunk per local shard.
    * Plain tensor: a single chunk covering the whole tensor.

    Raises:
        ValueError: if ``tensor`` is none of the supported types.
    """
    if hasattr(tensor, "__create_chunk_list__"):
        # DTensor implements _Checkpointable
        return tensor.__create_chunk_list__()  # type: ignore[attr-defined]
    if isinstance(tensor, ShardedTensor):
        return [_chunk_for_shard(shard.metadata) for shard in tensor.local_shards()]
    if isinstance(tensor, torch.Tensor):
        return [_create_chunk_from_tensor(tensor)]
    raise ValueError(
        "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
        f",but got {type(tensor)}"
    )
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
    """
    Create the ``ReadItem`` list needed to load ``obj`` from checkpoint metadata ``md``.

    For ``BytesStorageMetadata`` a single BYTE_IO read with zero offsets and
    zero length is returned. Otherwise ``obj`` must be tensor-like: its local
    chunks are matched against the saved chunks described by ``md``.

    Raises:
        ValueError: if ``md`` is not ``BytesStorageMetadata`` but ``obj`` is
            not a supported tensor type.
    """
    if isinstance(md, BytesStorageMetadata):
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]
    try:
        local_chunks = _create_chunk_list(obj)
    except ValueError as ex:
        raise ValueError(
            f"Invalid checkpoint metadata for {fqn}, "
            f"expected BytesStorageMetadata but found {type(md)}",
        ) from ex
    return create_read_items_for_chunk_list(fqn, md, local_chunks)
def _init_state_dict(state_dict: dict[str, Any]) -> Any:
    """
    Materialize meta-device tensors in ``state_dict`` in place.

    Every DTensor or plain tensor whose ``device`` is ``meta`` is replaced by
    an uninitialized (``torch.empty_like``) tensor on the current device of the
    default process group's device type. DTensors are rebuilt with their
    original mesh, placements, shape and stride. A ShardedTensor on the meta
    device is rejected. Values on any other device are left untouched.

    Raises:
        RuntimeError: if a ShardedTensor on the meta device is found.
    """
    meta_device = torch.device("meta")

    def _materialization_device():
        # Current device of the default process group's device type.
        device_type = dist.distributed_c10d._get_pg_default_device().type
        return cast(torch.device, _get_device_module(device_type).current_device())

    def dtensor_func(value: DTensor):
        if getattr(value, "device", None) != meta_device:
            return value
        target = _materialization_device()
        local_copy = torch.empty_like(value.to_local(), device=target)
        # We need to pass shape and stride explicitly, since DTensor might be
        # sharded unevenly.
        return DTensor.from_local(
            local_copy,
            device_mesh=value.device_mesh,
            placements=value.placements,
            shape=value.size(),
            stride=value.stride(),
        )

    def sharded_tensor_func(value: Any):
        if getattr(value, "device", None) == meta_device:
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
        return value

    def tensor_func(value: torch.Tensor):
        if getattr(value, "device", None) != meta_device:
            return value
        return torch.empty_like(value, device=_materialization_device())

    _iterate_state_dict(
        state_dict,
        dtensor_func,
        sharded_tensor_func,
        tensor_func,
    )
def _iterate_state_dict(
iter_object: Any,
dtensor_func: Callable,
sharded_tensor_func: Callable,
tensor_func: Callable,
):
"""
Iterate through the state dict, applying the given functions to each tensor type
and update the state dict in place.
Args:
iter_object (Any): the target state_dict.
sharded_tensor_func (Callable): the function to apply to ShardedTensor
dtensor_func (Callable): the function to apply to DTensor
tensor_func (Callable): the function to apply to Tensor
# TODO: let state_dict_util._iterate_state_dict() to support in place option
so we don't need to have two versions of _iterate_state_dict.
"""
if isinstance(iter_object, DTensor):
return dtensor_func(iter_object)
elif isinstance(iter_object, ShardedTensor):
return sharded_tensor_func(iter_object)
elif isinstance(iter_object, torch.Tensor):
return tensor_func(iter_object)
elif (
isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
or iter_object is None
):
return iter_object
elif isinstance(iter_object, dict):
for key, value in iter_object.items():
iter_object[key] = _iterate_state_dict(
value, dtensor_func, sharded_tensor_func, tensor_func
)
return iter_object
elif isinstance(iter_object, (list, tuple)):
ret = [
_iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
for v in iter_object
]
if isinstance(iter_object, tuple):
ret = tuple(ret) # type: ignore[assignment]
return ret