
Commit

move searching ports to distributed
xrsrke committed Jan 29, 2024
1 parent 1cf4da2 commit f6d9847
Showing 3 changed files with 19 additions and 18 deletions.
6 changes: 5 additions & 1 deletion src/nanotron/distributed.py
@@ -9,6 +9,8 @@
from torch.distributed import * # noqa
from torch.distributed.distributed_c10d import ProcessGroup

+from nanotron.utils import find_free_port
+
torch_version_above_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0")
Work = dist.Work if torch_version_above_1_13 else dist._Work
default_pg_timeout = datetime.timedelta(minutes=10)
@@ -257,5 +259,7 @@ def initialize_torch_distributed():
        backend = "gloo"

    # Call the init process.
-    dist.init_process_group(backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout)
+    port = find_free_port()
+    init_method = f"tcp://localhost:{port}"
+    dist.init_process_group(init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout)
    return True
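
Note: for reference, a minimal standalone sketch of the init pattern introduced above, assuming a single-process run (world_size == 1); with multiple ranks, every process must agree on the same endpoint, so a per-process random port only works if the chosen port is shared between ranks. The function name init_single_process_group is illustrative and not part of nanotron.

import datetime

import torch.distributed as dist

from nanotron.utils import find_free_port  # helper added to src/nanotron/utils.py in this commit


def init_single_process_group() -> None:
    # Pick an unused local port and build an explicit TCP init method,
    # instead of relying on the MASTER_ADDR/MASTER_PORT environment variables.
    port = find_free_port()
    init_method = f"tcp://localhost:{port}"
    dist.init_process_group(
        init_method=init_method,
        backend="gloo",  # "nccl" when CUDA is available
        world_size=1,
        rank=0,
        timeout=datetime.timedelta(minutes=10),
    )


if __name__ == "__main__":
    init_single_process_group()
    print(dist.get_rank(), dist.get_world_size())
    dist.destroy_process_group()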
14 changes: 14 additions & 0 deletions src/nanotron/utils.py
@@ -4,6 +4,8 @@
import os
from contextlib import ExitStack, contextmanager
from typing import Callable, ContextManager, List, Optional
+import random
+import socket

import torch
from packaging import version
@@ -147,3 +149,15 @@ def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: to
    tensor = torch.empty([], dtype=dtype, device=device)
    tensor.set_(source=untyped_storage)
    return tensor
+
+
+def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
+    while True:
+        port = random.randint(min_port, max_port)
+        try:
+            with socket.socket() as sock:
+                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                sock.bind(("localhost", port))
+                return port
+        except OSError as e:
+            raise e
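
Note: find_free_port probes random ports until a bind succeeds, and because the except branch re-raises, a port that is already in use surfaces as an OSError rather than triggering another attempt. A common alternative, not used in this commit, is to bind to port 0 and let the OS assign a free ephemeral port; a minimal sketch, where the name find_free_port_os_assigned is illustrative:

import socket


def find_free_port_os_assigned(host: str = "localhost") -> int:
    # Binding to port 0 asks the operating system for any free ephemeral port;
    # getsockname() then reports the port that was actually assigned.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, 0))
        return sock.getsockname()[1]

Either way the socket is closed before the caller rebinds the port, so a small race window remains; both helpers are best-effort.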
17 changes: 0 additions & 17 deletions tests/helpers/utils.py
@@ -2,25 +2,11 @@
import os
import uuid
from typing import Any, Dict, List, Optional, Tuple
-# import random
-# import socket

import torch.cuda
from nanotron.parallel import ParallelContext
from torch.distributed.launcher import elastic_launch


-# def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
-#     while True:
-#         port = random.randint(min_port, max_port)
-#         try:
-#             with socket.socket() as sock:
-#                 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-#                 sock.bind(("localhost", port))
-#                 return port
-#         except OSError as e:
-#             raise e

def available_gpus():
    if not torch.cuda.is_available():
        return 0
@@ -106,8 +92,6 @@ def _init_distributed(func):
    nb_gpus = tp * dp * pp
    run_id = uuid.uuid4()

-    # port = find_free_port()
-
    config = torch.distributed.launcher.LaunchConfig(
        min_nodes=1,
        max_nodes=1,
@@ -116,7 +100,6 @@ def _init_distributed(func):
        rdzv_configs={"timeout": 60},
        # Setting port to `0` allows `torch` to randomly pick a port: https://pytorch.org/docs/stable/elastic/run.html#stacked-single-node-multi-worker
        # Works only for single node workload.
-        # rdzv_endpoint=f"localhost:{port}",
        rdzv_endpoint=f"localhost:0",
        run_id=str(run_id),
        max_restarts=0,
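Note: for context, a rough self-contained sketch of how the elastic launcher is typically driven with a config like the one above, assuming a single node and a gloo process group; the worker function _worker and the nproc_per_node/rdzv_backend values are illustrative and not taken from the nanotron test helper.

import uuid

import torch.distributed as dist
from torch.distributed.launcher import LaunchConfig, elastic_launch


def _worker() -> int:
    # torchelastic exports MASTER_ADDR/MASTER_PORT, RANK and WORLD_SIZE,
    # so the default env:// init method is enough here.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    dist.destroy_process_group()
    return rank


if __name__ == "__main__":
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,
        rdzv_backend="c10d",
        # Port 0 lets torch pick a free rendezvous port on a single node,
        # which is why the test helper no longer needs its own find_free_port.
        rdzv_endpoint="localhost:0",
        run_id=str(uuid.uuid4()),
        max_restarts=0,
    )
    results = elastic_launch(config=config, entrypoint=_worker)()
    print(results)  # maps local rank to each worker's return value, e.g. {0: 0, 1: 1}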
