from concurrent.futures import Future
from typing import Any, Optional

import torch.distributed as dist
import torch.distributed.checkpoint.state_dict_loader as loader
import torch.distributed.checkpoint.state_dict_saver as saver
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import (
    LoadPlanner,
    SavePlanner,
    StorageReader,
    StorageWriter,
)


__all__: list[str] = []

class _Checkpointer:
    """This base class specifies a high-level API for saving and loading
    distributed ``state_dict`` objects. It provides an abstraction over the
    low-level APIs provided by :py:mod:`torch.distributed.checkpoint.storage`,
    essentially calling :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`
    and :py:meth:`torch.distributed.checkpoint.state_dict_loader.load` with the
    provided storage readers and writers.

    .. warning::
        This feature is experimental and subject to removal/change.
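
    Example (a minimal sketch, assuming the ``FileSystemWriter`` and
    ``FileSystemReader`` provided by :py:mod:`torch.distributed.checkpoint`,
    plus a hypothetical ``model`` and checkpoint path)::

        >>> # xdoctest: +SKIP
        >>> import torch.distributed.checkpoint as dcp
        >>> checkpointer = _Checkpointer(
        ...     storage_writer=dcp.FileSystemWriter("/tmp/checkpoint"),
        ...     storage_reader=dcp.FileSystemReader("/tmp/checkpoint"),
        ... )
        >>> metadata = checkpointer.save(model.state_dict())  # write checkpoint
        >>> state_dict = model.state_dict()
        >>> checkpointer.load(state_dict)  # load checkpoint in place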
    """

    def __init__(
        self,
        storage_writer: StorageWriter,
        storage_reader: StorageReader,
        *,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ):
        """Initializes the Checkpointer instance.

        Args:
            storage_writer: Instance of StorageWriter used to perform writes.
            storage_reader: Instance of StorageReader used to load data from.
            process_group: ProcessGroup to be used for cross-rank synchronization.
            coordinator_rank: Rank used to coordinate the checkpoint; rank 0 is used by default.
            no_dist: If ``True``, the distributed checkpoint will not be saved in SPMD style. (Default: ``False``)
            load_planner: Instance of LoadPlanner to use when loading.
            save_planner: Instance of SavePlanner to use when saving.
        """
        self.storage_writer = storage_writer
        self.storage_reader = storage_reader
        self.process_group = process_group
        self.coordinator_rank = coordinator_rank
        self.no_dist = no_dist
        self.load_planner = load_planner
        self.save_planner = save_planner

    def save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Metadata:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.save`, utilizing the values passed during initialization."""
        return saver.save(
            state_dict,
            self.storage_writer,
            process_group=self.process_group,
            coordinator_rank=self.coordinator_rank,
            no_dist=self.no_dist,
            planner=self.save_planner,
        )

    def async_save(
        self,
        state_dict: STATE_DICT_TYPE,
    ) -> Future:
        """
        Calls :py:meth:`torch.distributed.checkpoint.state_dict_saver.async_save`, utilizing the values passed during initialization.

        Returns:
            Future: A future holding the resultant Metadata object from ``save``.
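
        Example (a minimal sketch; ``checkpointer`` is a hypothetical
        ``_Checkpointer`` instance and ``model`` a hypothetical module)::

            >>> # xdoctest: +SKIP
            >>> future = checkpointer.async_save(model.state_dict())
            >>> # ... do other work while the checkpoint is being written ...
            >>> metadata = future.result()  # block until the save completes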
        """
        return saver.async_save(
            state_dict,
            storage_writer=self.storage_writer,
            process_group=self.process_group,
            planner=self.save_planner,
        )

    def load(self, state_dict: dict[str, Any]) -> None:
        """Calls :py:meth:`torch.distributed.checkpoint.state_dict_loader.load`, utilizing the values passed during initialization."""
        loader.load(
            state_dict,
            storage_reader=self.storage_reader,
            process_group=self.process_group,
            planner=self.load_planner,
        )