@@ -54,14 +54,14 @@ def find_script_paths(benchmark_type: str) -> list[str]:
5454 return scripts
5555
5656
57- def get_current_tmux_sessions () -> list [str ]:
57+ def get_current_tmux_sessions (benchmark_type : str ) -> list [str ]:
5858 """List active tmux sessions starting with TMUX_SESSION_PREFIX."""
5959 try :
6060 output = subprocess .check_output (["tmux" , "list-sessions" ], stderr = subprocess .DEVNULL )
6161 sessions = [
6262 line .split (b":" )[0 ].decode ()
6363 for line in output .splitlines ()
64- if line .startswith (TMUX_SESSION_PREFIX .encode ())
64+ if line .startswith (f" { TMUX_SESSION_PREFIX } _ { benchmark_type } " .encode ())
6565 ]
6666 return sessions
6767 except subprocess .CalledProcessError :
@@ -72,6 +72,7 @@ def start_tmux_session(
7272 script_path : str ,
7373 persistence_base_dir : Path | str ,
7474 num_experiments : int ,
75+ benchmark_type : str ,
7576 task : str ,
7677 max_epochs : int | None = None ,
7778 epoch_num_steps : int | None = None ,
@@ -84,7 +85,10 @@ def start_tmux_session(
8485
8586 # Include task name in session to avoid collisions when running multiple tasks
8687 script_name = Path (script_path ).name .replace ("_hl.py" , "" )
87- session_name = f"{ TMUX_SESSION_PREFIX } { task } _{ script_name } "
88+ # Remove benchmark_type from name since we add it explicitly below
89+ script_name = script_name .replace (benchmark_type , "" ).strip ("_" )
90+
91+ session_name = f"{ TMUX_SESSION_PREFIX } _{ benchmark_type } _{ task } _{ script_name } "
8892
8993 # Build command with optional max_epochs and epoch_num_steps
9094 cmd_args = f"{ python_exec } { script_path } --num_experiments { num_experiments } --persistence_base_dir { persistence_base_dir } --task { task } "
@@ -161,7 +165,7 @@ def aggregate_rliable_results(task_results_dir: str | Path) -> None:
161165
162166
163167def main (
164- max_concurrent_sessions : int = 2 ,
168+ max_concurrent_sessions : int | None = None ,
165169 benchmark_type : str = "mujoco" ,
166170 num_experiments : int = 10 ,
167171 max_scripts : int = - 1 ,
@@ -177,7 +181,7 @@ def main(
177181 towards the max_concurrent_sessions limit. You can terminate all sessions with
178182 `tmux kill-server`.
179183
180- :param max_concurrent_sessions: how many scripts to run in parallel, each script will
184+ :param max_concurrent_sessions: optionally restrict how many tmux sessions to open in parallel, each script will
181185 run in a tmux session
182186 :param benchmark_type: mujoco or atari
183187 :param num_experiments: number of experiments to run per script
@@ -224,11 +228,11 @@ def main(
224228 for i_script , script in enumerate (scripts , start = 1 ):
225229 # Wait for free slot
226230 has_printed_waiting_message = False
227- while len (get_current_tmux_sessions ()) >= max_concurrent_sessions :
231+ while len (get_current_tmux_sessions (benchmark_type )) >= max_concurrent_sessions :
228232 if not has_printed_waiting_message :
229233 log .info (
230234 f"Max concurrent sessions reached ({ max_concurrent_sessions } ). "
231- f"Current sessions:\n { get_current_tmux_sessions ()} \n Waiting for a free slot..."
235+ f"Current sessions:\n { get_current_tmux_sessions (benchmark_type )} \n Waiting for a free slot..."
232236 )
233237 has_printed_waiting_message = True
234238 time .sleep (SESSION_CHECK_INTERVAL )
@@ -247,11 +251,11 @@ def main(
247251
248252 has_printed_final_waiting_message = False
249253 # Wait for all sessions to complete before moving to next task
250- while len (get_current_tmux_sessions ()) > 0 :
254+ while len (get_current_tmux_sessions (benchmark_type )) > 0 :
251255 if not has_printed_final_waiting_message :
252256 log .info (
253257 f"All scripts for task '{ task } ' have been started, waiting for completion of remaining tmux sessions:\n "
254- f"{ get_current_tmux_sessions ()} "
258+ f"{ get_current_tmux_sessions (benchmark_type )} "
255259 )
256260 has_printed_final_waiting_message = True
257261 time .sleep (COMPLETION_CHECK_INTERVAL )
0 commit comments