Skip to content

Commit 51493eb

Browse files
committed
Configurable experiment launcher in run_benchmark.py
1 parent f368fe7 commit 51493eb

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

benchmark/run_benchmark.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import time
55
from pathlib import Path
6+
from typing import Literal
67

78
from sensai.util import logging
89
from sensai.util.logging import datetime_tag
@@ -76,6 +77,7 @@ def start_tmux_session(
7677
task: str,
7778
max_epochs: int | None = None,
7879
epoch_num_steps: int | None = None,
80+
experiment_launcher: Literal["sequential", "joblib"] | None = None,
7981
) -> bool:
8082
"""Start a tmux session running the given experiment script, returning True on success."""
8183
# Normalize paths for Git Bash / Windows compatibility
@@ -96,6 +98,8 @@ def start_tmux_session(
9698
cmd_args += f" --max_epochs {max_epochs}"
9799
if epoch_num_steps is not None:
98100
cmd_args += f" --epoch_num_steps {epoch_num_steps}"
101+
if experiment_launcher is not None:
102+
cmd_args += f" --experiment_launcher {experiment_launcher}"
99103

100104
cmd = [
101105
"tmux",
@@ -166,13 +170,14 @@ def aggregate_rliable_results(task_results_dir: str | Path) -> None:
166170

167171
def main(
168172
max_concurrent_sessions: int | None = None,
169-
benchmark_type: str = "mujoco",
173+
benchmark_type: Literal["mujoco", "atari"] = "mujoco",
170174
num_experiments: int = 10,
171175
max_scripts: int = -1,
172176
tasks: list[str] | None = None,
173177
max_tasks: int = -1,
174178
max_epochs: int | None = None,
175179
epoch_num_steps: int | None = None,
180+
experiment_launcher: Literal["sequential", "joblib"] | None = None,
176181
) -> None:
177182
"""
178183
Run the benchmarking by executing each high level script in its default configuration
@@ -190,6 +195,8 @@ def main(
190195
:param max_tasks: maximum number of tasks to run, -1 for all. Set this to a low number for testing.
191196
:param max_epochs: optional maximum number of training epochs to pass to all scripts. If None, uses script defaults.
192197
:param epoch_num_steps: optional number of environment steps per epoch to pass to all scripts. If None, uses script defaults.
198+
:param experiment_launcher: type of experiment launcher to use, only has an effect if `num_experiments>1`.
199+
By default, will use the experiment launchers defined in the individual scripts.
193200
:return:
194201
"""
195202
# Use default tasks if none provided
@@ -248,6 +255,7 @@ def main(
248255
task=task,
249256
max_epochs=max_epochs,
250257
epoch_num_steps=epoch_num_steps,
258+
experiment_launcher=experiment_launcher,
251259
)
252260
if session_started:
253261
time.sleep(TMUX_SESSION_START_DELAY) # Give tmux a moment to start the session

0 commit comments

Comments
 (0)