Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
42
venv/Lib/site-packages/torch/distributed/checkpoint/api.py
Normal file
42
venv/Lib/site-packages/torch/distributed/checkpoint/api.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import traceback as tb
|
||||
from typing import Any
|
||||
|
||||
|
||||
WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary]
|
||||
|
||||
__all__ = ["CheckpointException"]
|
||||
|
||||
|
||||
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
|
||||
return (exc, tb.extract_tb(exc.__traceback__))
|
||||
|
||||
|
||||
def _is_wrapped_exception(obj: Any) -> bool:
|
||||
if not isinstance(obj, tuple):
|
||||
return False
|
||||
if len(obj) != 2:
|
||||
return False
|
||||
return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
|
||||
|
||||
|
||||
class CheckpointException(BaseException):
|
||||
"""Exception raised if failure was detected as part of a checkpoint load or save."""
|
||||
|
||||
def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
|
||||
super().__init__(msg, failures)
|
||||
self._failures = failures
|
||||
|
||||
@property
|
||||
def failures(self) -> dict[int, WRAPPED_EXCEPTION]:
|
||||
"""Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
|
||||
return self._failures
|
||||
|
||||
def __str__(self) -> str:
|
||||
str = f"CheckpointException ranks:{self._failures.keys()}\n"
|
||||
for rank, exc_pair in self._failures.items():
|
||||
exc, trace = exc_pair
|
||||
str += f"Traceback (most recent call last): (RANK {rank})\n"
|
||||
if trace is not None:
|
||||
str += "".join(tb.format_list(trace))
|
||||
str += "".join(tb.format_exception_only(type(exc), value=exc))
|
||||
return str
|
Loading…
Add table
Add a link
Reference in a new issue