Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
285
venv/Lib/site-packages/torch/distributed/rpc/internal.py
Normal file
@@ -0,0 +1,285 @@
# mypy: allow-untyped-defs
import collections
import copyreg
import io
import pickle
import sys
import threading
import traceback
from enum import Enum

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import _get_current_rpc_agent


__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]

# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects
_thread_local_tensor_tables = threading.local()
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler


class RPCExecMode(Enum):
    SYNC = "sync"
    ASYNC = "async"
    ASYNC_JIT = "async_jit"
    REMOTE = "remote"
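

# Illustrative note (not part of the module): these modes mirror the public RPC
# entry points (e.g. rpc.rpc_sync, rpc.rpc_async, rpc.remote), and their .value
# strings feed the profiling keys built further below, e.g.
#
#   RPCExecMode.ASYNC_JIT.value  # "async_jit"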


class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data into a "binary string + tensor table" format.
    For an RPC Python UDF function and its args, non-tensor data is serialized
    into a regular binary string, while tensor data is put into thread-local
    tensor tables. This serialization format is consistent with builtin
    operators and args using the JIT pickler, and it makes tensor handling in
    C++ much easier, e.g. attaching a tensor to the distributed autograd graph
    in C++.
    """

    def __init__(self):
        # Ignore type error because dispatch_table is defined in third-party package
        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
        self._dispatch_table[torch.Tensor] = self._tensor_reducer
        # Used for registering customized picklers.
        self._class_reducer_dict = {}

    def _register_reducer(self, obj_class, reducer):
        # For the same class, only register the reducer once.
        if obj_class not in self._class_reducer_dict:
            self._class_reducer_dict[obj_class] = reducer
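
    # Illustrative sketch (hypothetical MyType/my_reducer, not part of the
    # module): custom classes can be routed through this hook, e.g.
    #
    #   _internal_rpc_pickler._register_reducer(MyType, my_reducer)
    #
    # where my_reducer(obj) returns a (receiver_callable, args) pair in the
    # same spirit as _tensor_reducer below.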

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        global _thread_local_tensor_tables
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, tensor):
        global _thread_local_tensor_tables
        _thread_local_tensor_tables.send_tables.append(tensor)
        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index,))
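
    # Illustrative sketch (not part of the module): the pair above follows
    # pickle's standard per-pickler reducer protocol. Roughly, a pickler built
    # as
    #
    #   p = pickle.Pickler(buffer)
    #   p.dispatch_table = {torch.Tensor: self._tensor_reducer}
    #
    # emits a (callable, args) pair for each tensor seen during p.dump(), so
    # only the integer index lands in the byte stream while the tensor itself
    # stays in _thread_local_tensor_tables.send_tables; _tensor_receiver looks
    # the tensor up by that index on the receiving side.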

    @classmethod
    def _py_rref_receiver(cls, rref_fork_data):
        return dist.rpc.PyRRef._deserialize(rref_fork_data)

    def _py_rref_reducer(self, py_rref):
        rref_fork_data = py_rref._serialize()
        return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))

    def _rref_reducer(self, rref):
        return self._py_rref_reducer(rref)

    @classmethod
    def _script_module_receiver(cls, script_module_serialized):
        """
        Given a serialized representation of a ScriptModule created with torch.jit.save,
        loads and returns the ScriptModule.
        """
        f = io.BytesIO(script_module_serialized)
        m = torch.jit.load(f)
        return m

    def _script_module_reducer(self, script_module):
        """
        Serializes a ScriptModule.
        """
        f = io.BytesIO()
        torch.jit.save(script_module, f)
        return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))
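
    # Illustrative sketch (assumed `my_module`, not part of the module): the
    # reducer/receiver pair above is a plain in-memory round trip:
    #
    #   buf = io.BytesIO()
    #   torch.jit.save(torch.jit.script(my_module), buf)        # reducer side
    #   restored = torch.jit.load(io.BytesIO(buf.getvalue()))   # receiver side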

    def serialize(self, obj):
        r"""
        Serialize non-tensor data into a binary string and tensor data into
        the tensor table.
        """
        f = io.BytesIO()
        p = _pickler(f)
        p.dispatch_table = self._dispatch_table

        # The RPC API can accept user picklers inheriting from
        # _InternalRPCPickler to serialize RRefs. User picklers may have a
        # different initialization function from _InternalRPCPickler, but they
        # should all call serialize() and use _rref_reducer to pickle RRefs in
        # Python. Also, when _internal_rpc_pickler is imported into rpc/api.py,
        # rpc.RRef is not compiled yet, so the _InternalRPCPickler constructor
        # is not a good place to access rpc.RRef; hence RRef's dispatch-table
        # entries are installed here.
        #
        # The return value of a `rpc.remote(..)` call is of type `rpc.PyRRef`.
        # The deserialized RRef object on the RPC receiver side is of type `rpc.PyRRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
        # An RRef created locally by the RRef Python constructor is of type `rpc.RRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]

        # Add dispatch pickling for ScriptModule or its subclass.
        if isinstance(obj, torch.jit.ScriptModule):
            # Ignore type error because dispatch_table is defined in third-party package
            p.dispatch_table[obj.__class__] = self._script_module_reducer  # type: ignore[index]

        # Install customized picklers.
        for class_name in self._class_reducer_dict.keys():
            p.dispatch_table[class_name] = self._class_reducer_dict[class_name]  # type: ignore[index]

        # Save _thread_local_tensor_tables.send_tables if this is a nested call.
        global _thread_local_tensor_tables
        if hasattr(_thread_local_tensor_tables, "send_tables"):
            old_send_tables = _thread_local_tensor_tables.send_tables
        else:
            old_send_tables = None
        _thread_local_tensor_tables.send_tables = []

        p.dump(obj)

        # Restore _thread_local_tensor_tables.send_tables if returning
        # from a nested call, otherwise clean up the table.
        tensors = _thread_local_tensor_tables.send_tables
        if old_send_tables is not None:
            _thread_local_tensor_tables.send_tables = old_send_tables
        else:
            del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)
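
    # Illustrative sketch (hypothetical reducer, not part of the module) of why
    # send_tables is saved and restored above: a custom reducer may itself call
    # serialize() on the same thread, e.g.
    #
    #   def _wrapper_reducer(obj):
    #       inner = _internal_rpc_pickler.serialize(obj.payload)  # nested call
    #       return (_wrapper_receiver, (inner,))
    #
    # and without the save/restore, the nested call would clobber the outer
    # call's tensor table.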

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize a binary string + tensor table back into the original obj.
        """
        # Save _thread_local_tensor_tables.recv_tables if this is a nested call.
        global _thread_local_tensor_tables
        if hasattr(_thread_local_tensor_tables, "recv_tables"):
            old_recv_tables = _thread_local_tensor_tables.recv_tables
        else:
            old_recv_tables = None
        _thread_local_tensor_tables.recv_tables = tensor_table

        try:
            unpickler = _unpickler(io.BytesIO(binary_data))
            ret = unpickler.load()
        except AttributeError as e:
            # Occurs when a function is not found on the module/class during
            # unpickling.
            except_str = (
                str(e)
                + """ Default RPC pickler does not serialize
            function code. Ensure that UDFs are defined on both caller and
            callee modules."""
            )
            ret = AttributeError(except_str)
            # Ensure the stack trace gets preserved
            ret.__cause__ = e

        # Restore _thread_local_tensor_tables.recv_tables if returning
        # from a nested call, otherwise clean up the table.
        if old_recv_tables is not None:
            _thread_local_tensor_tables.recv_tables = old_recv_tables
        else:
            del _thread_local_tensor_tables.recv_tables

        return ret


# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()


def serialize(obj):
    return _internal_rpc_pickler.serialize(obj)


def deserialize(binary_data, tensor_table):
    return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
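

# Illustrative round trip (assumed payload, not part of the module):
#
#   payload = {"w": torch.ones(2, 2), "step": 7}
#   binary_data, tensor_table = serialize(payload)
#   restored = deserialize(binary_data, tensor_table)
#   assert torch.equal(restored["w"], payload["w"]) and restored["step"] == 7
#
# The binary string carries only a tensor index; the tensor itself travels in
# tensor_table.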


def _run_function(python_udf):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.

    Runs a Python UDF and returns its return value.
    Wraps any exception in ``RemoteException`` if the function raises.
    """
    try:
        if isinstance(python_udf, AttributeError):
            raise python_udf
        result = python_udf.func(*python_udf.args, **python_udf.kwargs)
    except Exception as e:
        # except_str = exception info + traceback string
        except_str = (
            f"On {_get_current_rpc_agent().get_worker_info()}:\n"
            f"{repr(e)}\n{traceback.format_exc()}"
        )
        print(except_str, file=sys.stderr)
        result = RemoteException(except_str, type(e))
    return result
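

# Illustrative sketch (hypothetical call, not part of the module): C++ hands
# this function a PythonUDF (defined at the bottom of this file), e.g.
#
#   udf = PythonUDF(func=torch.add, args=(torch.ones(2), torch.ones(2)), kwargs={})
#   out = _run_function(udf)  # tensor([2., 2.]) on success
#
# On failure the except path builds a RemoteException instead (note that it
# queries the current RPC agent for the worker name).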


def _handle_exception(result):
    if isinstance(result, RemoteException):
        exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
        # We wrap exception re-creation here in case some exception classes
        # cannot be constructed directly from a string.
        exc = None
        try:
            exc = result.exception_type(exception_msg)
        except BaseException as e:
            raise RuntimeError(  # noqa: B904
                f"Failed to create original exception type. Error msg was {str(e)}"
                f" Original exception on remote side was {exception_msg}"
            ) from e

        if exc is not None:
            raise exc
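

# Illustrative sketch (assumed values, not part of the module): pairing with
# _run_function's error path,
#
#   result = RemoteException(msg="On worker1:\nValueError: bad input",
#                            exception_type=ValueError)
#   _handle_exception(result)  # raises ValueError with the remote message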


def _build_rpc_profiling_key(
    exec_type, func_name, current_worker_name, dst_worker_name
):
    """
    Builds the key that RPC calls are profiled with using the autograd profiler.
    This will be the name of the corresponding Event recorded in the profiler.

    Args:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dst_worker_name (str): Name of the destination worker.

    Returns:
        String representing profiling key
    """
    profile_key = (
        f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})"
    )
    return profile_key
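

# Illustrative example (assumed worker names):
#
#   _build_rpc_profiling_key(RPCExecMode.SYNC, "torch.add", "worker0", "worker1")
#   # -> "rpc_sync#torch.add(worker0 -> worker1)"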


def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
    """
    This function should be called from RPC/RRef functions to create a
    RecordFunction object for profiling. This function also runs the before
    callbacks that start the profiling, though the user is responsible for
    running the appropriate callbacks when the function to be profiled finishes.

    Args:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dest_worker_name (str): Name of the destination worker.

    Returns:
        An instance of `torch.autograd._RecordFunction`.
    """
    assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
    profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})"
    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
    return rf
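

# Illustrative usage sketch (assumed usage, not part of the module; running the
# matching "after" callbacks on completion is the caller's responsibility, as
# noted above):
#
#   with torch.autograd.profiler.profile():
#       rf = _start_record_function(
#           RPCExecMode.ASYNC, "my_func", "worker0", "worker1"
#       )
#       ...  # issue the RPC, then run the end callbacks when it completes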


PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])