185 lines
5.8 KiB
Python
185 lines
5.8 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# mypy: allow-untyped-defs
|
||
|
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the BSD-style license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
import datetime
|
||
|
import os
|
||
|
import socket
|
||
|
from contextlib import closing
|
||
|
from typing import Optional
|
||
|
|
||
|
import torch.distributed as dist
|
||
|
from torch.distributed.elastic.utils.logging import get_logger
|
||
|
from torch.distributed.elastic.utils.store import barrier
|
||
|
|
||
|
|
||
|
# Public API of this module.
__all__ = [
    "create_c10d_store",
    "get_free_port",
    "get_socket_with_port",
]

logger = get_logger(__name__)

# Exact ``RuntimeError`` messages this module matches on. The underlying
# exception types are not exposed through the bindings, so the error text is
# compared instead.
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"

# Key prefix for the barrier that waits for all store members to join.
_TCP_STORE_INIT = "_tcp_store/num_members"
|
||
|
|
||
|
|
||
|
def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries: int = 3,
    use_libuv: Optional[bool] = None,
):
    """
    Create a c10d ``TCPStore`` on ``server_addr``.

    Args:
        is_server: whether this process hosts the store (passed to
            ``TCPStore`` as ``is_master``).
        server_addr: host name of the store.
        server_port: port of the store. ``-1`` means "pick a free local port"
            (via :func:`get_free_port`), which is only allowed when
            ``world_size == 1``.
        world_size: number of store members.
        timeout: timeout in seconds, used both for the store and for the
            full-rank barrier.
        wait_for_workers: forwarded to ``TCPStore``. When True, the full-rank
            check also runs after the store is created.
        retries: maximum number of attempts when a dynamically picked port
            (``server_port == -1``) turns out to be already in use. At least
            one attempt is always made. Ignored when ``server_port`` is given
            explicitly.
        use_libuv: deprecated and ignored. The ``USE_LIBUV`` environment
            variable decides instead (``"1"``, the default, enables libuv).

    Returns:
        The created ``torch.distributed.TCPStore``.

    Raises:
        ValueError: if ``server_port`` is ``-1`` while ``world_size > 1``.
        RuntimeError: if the port is already in use and no attempts remain,
            or any other ``RuntimeError`` raised while creating the store.
        TimeoutError: propagated from ``_check_full_rank`` when not all
            members join in time.
    """
    if use_libuv is not None:
        logger.warning(
            "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment "
            'variable to "0" to disable libuv, or "1" to enable it. If the env var '
            "is not set, libuv will be used by default."
        )

    # check os.environ for use_libuv
    use_libuv = os.environ.get("USE_LIBUV", "1") == "1"  # libuv is the default option

    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        logger.info("server_port: %s, specified, ignoring retries", server_port)

    # Only retry when server_port is NOT static. A dynamically picked port may
    # be taken by another process before the store binds it, so a fresh port
    # is picked on every attempt. Retrying a static port would just hit the
    # same conflict, so it gets exactly one attempt.
    max_attempts = retries if server_port == -1 else 1
    attempt = 1
    while True:
        port = server_port if server_port != -1 else get_free_port()

        logger.info(
            "Creating c10d store on %s:%s\n"
            "  world_size  : %s\n"
            "  is_server   : %s\n"
            "  timeout(sec): %s\n"
            "  use_libuv   : %s\n",
            server_addr,
            port,
            world_size,
            is_server,
            timeout,
            use_libuv,
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
                use_libuv=use_libuv,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size, timeout=timeout)
            logger.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) != _ADDRESS_IN_USE:
                raise
            # this will only happen on the server
            if attempt >= max_attempts:
                raise RuntimeError(
                    f"on {server_addr}, port: {port} already in use"
                ) from e
            logger.warning(
                "port: %s already in use, attempt: [%s/%s]",
                port,
                attempt,
                max_attempts,
            )
            attempt += 1
|
||
|
|
||
|
|
||
|
def _check_full_rank(store, world_size, timeout):
    """
    Wait at the store barrier until all ``world_size`` members have joined.

    Runs :func:`barrier` on ``store`` under the ``_TCP_STORE_INIT`` key prefix,
    with ``timeout`` as the barrier timeout. A ``RuntimeError`` whose message is
    exactly ``_SOCKET_TIMEOUT`` is translated into a ``TimeoutError``. Any other
    ``RuntimeError`` propagates unchanged.
    """
    try:
        barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
    except RuntimeError as err:
        # Only the exact timeout message is translated; anything else is not ours.
        if str(err) != _SOCKET_TIMEOUT:
            raise
        raise TimeoutError(
            f"timed out waiting for all {world_size} members to join"
        ) from err
|
||
|
|
||
|
|
||
|
def get_free_port():
    """
    Return an unused port on localhost.

    Obtains a bound socket from :func:`get_socket_with_port`, reads back the
    port number it was given, and closes the socket before returning.

    Returns:
        int: an unused port on localhost

    Example:
        >>> # xdoctest: +SKIP("Nondeterministic")
        >>> get_free_port()
        63976

    .. note::
        The port returned by :func:`get_free_port` is not reserved and may be
        taken by another process after this function returns.
    """
    with closing(get_socket_with_port()) as sock:
        # getsockname() is (host, port, ...) for both IPv4 and IPv6 sockets.
        port = sock.getsockname()[1]
        return port
|
||
|
|
||
|
|
||
|
def get_socket_with_port() -> socket.socket:
    """
    Return a listening socket bound to a free port on localhost.

    The port is "reserved" only while the returned socket stays open. Close
    the socket before passing the port to the entity that requires it.
    Usage example

    ::

        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
        # there is still a race-condition that some other process
        # may grab this port before func() runs
        func(port)

    Every address that ``localhost`` resolves to (via ``socket.getaddrinfo``)
    is tried in order. An address whose socket cannot be created, bound, or
    put into listening mode is logged and skipped.

    Returns:
        socket.socket: a socket bound to ``("localhost", 0)`` and listening.

    Raises:
        RuntimeError: if no resolved address yields a usable socket.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for family, sock_type, proto, _, _ in addrs:
        # Socket creation itself can fail (e.g. an address family the host
        # does not support). Move on to the next candidate instead of
        # aborting the whole search.
        try:
            sock = socket.socket(family, sock_type, proto)
        except OSError as e:
            logger.warning("Socket creation attempt failed.", exc_info=e)
            continue
        try:
            # Port 0 lets the OS pick a free port.
            sock.bind(("localhost", 0))
            sock.listen(0)
            return sock
        except OSError as e:
            sock.close()
            logger.warning("Socket creation attempt failed.", exc_info=e)

    raise RuntimeError("Failed to create a socket")
|