Skip to content

Commit c3cf417

Browse files
committed
Benchmark script: minor improvement in tmux session counting
1 parent 2bcb29f commit c3cf417

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

benchmark/run_benchmark.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def find_script_paths(benchmark_type: str) -> list[str]:
5454
return scripts
5555

5656

57-
def get_current_tmux_sessions() -> list[str]:
57+
def get_current_tmux_sessions(benchmark_type: str) -> list[str]:
5858
"""List active tmux sessions whose names start with `{TMUX_SESSION_PREFIX}_{benchmark_type}`."""
5959
try:
6060
output = subprocess.check_output(["tmux", "list-sessions"], stderr=subprocess.DEVNULL)
6161
sessions = [
6262
line.split(b":")[0].decode()
6363
for line in output.splitlines()
64-
if line.startswith(TMUX_SESSION_PREFIX.encode())
64+
if line.startswith(f"{TMUX_SESSION_PREFIX}_{benchmark_type}".encode())
6565
]
6666
return sessions
6767
except subprocess.CalledProcessError:
@@ -72,6 +72,7 @@ def start_tmux_session(
7272
script_path: str,
7373
persistence_base_dir: Path | str,
7474
num_experiments: int,
75+
benchmark_type: str,
7576
task: str,
7677
max_epochs: int | None = None,
7778
epoch_num_steps: int | None = None,
@@ -84,7 +85,10 @@ def start_tmux_session(
8485

8586
# Include task name in session to avoid collisions when running multiple tasks
8687
script_name = Path(script_path).name.replace("_hl.py", "")
87-
session_name = f"{TMUX_SESSION_PREFIX}{task}_{script_name}"
88+
# Remove benchmark_type from name since we add it explicitly below
89+
script_name = script_name.replace(benchmark_type, "").strip("_")
90+
91+
session_name = f"{TMUX_SESSION_PREFIX}_{benchmark_type}_{task}_{script_name}"
8892

8993
# Build command with optional max_epochs and epoch_num_steps
9094
cmd_args = f"{python_exec} {script_path} --num_experiments {num_experiments} --persistence_base_dir {persistence_base_dir} --task {task}"
@@ -161,7 +165,7 @@ def aggregate_rliable_results(task_results_dir: str | Path) -> None:
161165

162166

163167
def main(
164-
max_concurrent_sessions: int = 2,
168+
max_concurrent_sessions: int | None = None,
165169
benchmark_type: str = "mujoco",
166170
num_experiments: int = 10,
167171
max_scripts: int = -1,
@@ -177,7 +181,7 @@ def main(
177181
towards the max_concurrent_sessions limit. You can terminate all sessions with
178182
`tmux kill-server`.
179183
180-
:param max_concurrent_sessions: how many scripts to run in parallel, each script will
184+
:param max_concurrent_sessions: optionally restrict how many tmux sessions to open in parallel; each script will
181185
run in a tmux session
182186
:param benchmark_type: mujoco or atari
183187
:param num_experiments: number of experiments to run per script
@@ -224,11 +228,11 @@ def main(
224228
for i_script, script in enumerate(scripts, start=1):
225229
# Wait for free slot
226230
has_printed_waiting_message = False
227-
while len(get_current_tmux_sessions()) >= max_concurrent_sessions:
231+
while len(get_current_tmux_sessions(benchmark_type)) >= max_concurrent_sessions:
228232
if not has_printed_waiting_message:
229233
log.info(
230234
f"Max concurrent sessions reached ({max_concurrent_sessions}). "
231-
f"Current sessions:\n{get_current_tmux_sessions()}\nWaiting for a free slot..."
235+
f"Current sessions:\n{get_current_tmux_sessions(benchmark_type)}\nWaiting for a free slot..."
232236
)
233237
has_printed_waiting_message = True
234238
time.sleep(SESSION_CHECK_INTERVAL)
@@ -247,11 +251,11 @@ def main(
247251

248252
has_printed_final_waiting_message = False
249253
# Wait for all sessions to complete before moving to next task
250-
while len(get_current_tmux_sessions()) > 0:
254+
while len(get_current_tmux_sessions(benchmark_type)) > 0:
251255
if not has_printed_final_waiting_message:
252256
log.info(
253257
f"All scripts for task '{task}' have been started, waiting for completion of remaining tmux sessions:\n"
254-
f"{get_current_tmux_sessions()}"
258+
f"{get_current_tmux_sessions(benchmark_type)}"
255259
)
256260
has_printed_final_waiting_message = True
257261
time.sleep(COMPLETION_CHECK_INTERVAL)

0 commit comments

Comments
 (0)