Skip to content

Commit def954b

Browse files
authored
Use find free port (#877)
1 parent b54dd6e commit def954b

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

habitat_baselines/rl/ddppo/ddp_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import contextlib
12
import functools
23
import os
34
import signal
5+
import socket
46
import subprocess
57
import threading
68
from os import path as osp
@@ -267,3 +269,21 @@ def init_distrib_slurm(
267269
)
268270

269271
return local_rank, tcp_store
272+
273+
274+
def find_free_port() -> int:
    """
    Returns a free port on the system.

    A TCP socket is bound to port 0 on localhost, which lets the operating
    system choose an unused port. That port number is read back and the
    socket is closed before returning, so the port is only known to have
    been free at the time of the call.

    Note that this can only be used to find a port for torch.distributed
    if it's called by a process on the node that will have
    world_rank == 0 and then all ranks are created. If you
    just called `find_free_port()` on each rank independently, every
    rank will have a different port!
    """
    # Socket objects are context managers themselves, so the socket is
    # closed on exit from the ``with`` block without an extra wrapper.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Port 0 asks the OS to assign any currently unused port.
        sock.bind(("localhost", 0))
        _, port = sock.getsockname()
        return port

test/test_baseline_trainers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from habitat_baselines.common.base_trainer import BaseRLTrainer
2424
from habitat_baselines.common.baseline_registry import baseline_registry
2525
from habitat_baselines.config.default import get_config
26+
from habitat_baselines.rl.ddppo.ddp_utils import find_free_port
2627
from habitat_baselines.run import execute_exp, run_exp
2728
from habitat_baselines.utils.common import (
2829
ObservationBatchingCache,
@@ -77,8 +78,8 @@ def _powerset(s):
7778
),
7879
)
7980
def test_trainers(test_cfg_path, mode, gpu2gpu, observation_transforms):
80-
# For testing with world_size=1, -1 works as port in PyTorch
81-
os.environ["MAIN_PORT"] = str(-1)
81+
# For testing with world_size=1
82+
os.environ["MAIN_PORT"] = str(find_free_port())
8283

8384
if gpu2gpu:
8485
try:

0 commit comments

Comments
 (0)