33import sys
44import time
55from pathlib import Path
6+ from typing import Literal
67
78from sensai .util import logging
89from sensai .util .logging import datetime_tag
@@ -76,6 +77,7 @@ def start_tmux_session(
7677 task : str ,
7778 max_epochs : int | None = None ,
7879 epoch_num_steps : int | None = None ,
80+ experiment_launcher : Literal ["sequential" , "joblib" ] | None = None ,
7981) -> bool :
8082 """Start a tmux session running the given experiment script, returning True on success."""
8183 # Normalize paths for Git Bash / Windows compatibility
@@ -96,6 +98,8 @@ def start_tmux_session(
9698 cmd_args += f" --max_epochs { max_epochs } "
9799 if epoch_num_steps is not None :
98100 cmd_args += f" --epoch_num_steps { epoch_num_steps } "
101+ if experiment_launcher is not None :
102+ cmd_args += f" --experiment_launcher { experiment_launcher } "
99103
100104 cmd = [
101105 "tmux" ,
@@ -166,13 +170,14 @@ def aggregate_rliable_results(task_results_dir: str | Path) -> None:
166170
167171def main (
168172 max_concurrent_sessions : int | None = None ,
169- benchmark_type : str = "mujoco" ,
173+ benchmark_type : Literal [ "mujoco" , "atari" ] = "mujoco" ,
170174 num_experiments : int = 10 ,
171175 max_scripts : int = - 1 ,
172176 tasks : list [str ] | None = None ,
173177 max_tasks : int = - 1 ,
174178 max_epochs : int | None = None ,
175179 epoch_num_steps : int | None = None ,
180+ experiment_launcher : Literal ["sequential" , "joblib" ] | None = None ,
176181) -> None :
177182 """
178183 Run the benchmarking by executing each high level script in its default configuration
@@ -190,6 +195,8 @@ def main(
190195 :param max_tasks: maximum number of tasks to run, -1 for all. Set this to a low number for testing.
191196 :param max_epochs: optional maximum number of training epochs to pass to all scripts. If None, uses script defaults.
192197 :param epoch_num_steps: optional number of environment steps per epoch to pass to all scripts. If None, uses script defaults.
198+ :param experiment_launcher: type of experiment launcher to use, only has an effect if `num_experiments>1`.
199+ By default, will use the experiment launchers defined in the individual scripts.
193200 :return:
194201 """
195202 # Use default tasks if none provided
@@ -248,6 +255,7 @@ def main(
248255 task = task ,
249256 max_epochs = max_epochs ,
250257 epoch_num_steps = epoch_num_steps ,
258+ experiment_launcher = experiment_launcher ,
251259 )
252260 if session_started :
253261 time .sleep (TMUX_SESSION_START_DELAY ) # Give tmux a moment to start the session
0 commit comments