Skip to content

Commit f28c819

Browse files
committed
Added a script for benchmarking
1 parent a34d7f8 commit f28c819

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

benchmark/run_benchmark.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import subprocess
2+
import sys
3+
import time
4+
from pathlib import Path
5+
6+
from sensai.util import logging
7+
from sensai.util.logging import datetime_tag
8+
9+
LOG_FILE = "session_log.txt"
10+
ERROR_LOG_FILE = "error_log.txt"
11+
TMUX_SESSION_PREFIX = "tianshou_"
12+
13+
log = logging.getLogger("benchmark")
14+
15+
16+
def find_script_paths(benchmark_type: str) -> list[str]:
17+
"""Return all Python scripts ending in _hl.py under examples/<benchmark_type>."""
18+
base_dir = Path(__file__).parent.parent / "examples" / benchmark_type
19+
glob_filter = "**/*_hl.py"
20+
if not base_dir.exists():
21+
raise FileNotFoundError(f"Directory '{base_dir}' does not exist.")
22+
23+
scripts = sorted(str(p) for p in base_dir.glob(glob_filter))
24+
if not scripts:
25+
raise FileNotFoundError(f"Did not find any scripts matching '*_hl.py' in '{base_dir}'.")
26+
return scripts
27+
28+
29+
def get_current_tmux_sessions() -> list[str]:
30+
"""List active tmux sessions starting with 'job_'."""
31+
try:
32+
output = subprocess.check_output(["tmux", "list-sessions"], stderr=subprocess.DEVNULL)
33+
sessions = [
34+
line.split(b":")[0].decode()
35+
for line in output.splitlines()
36+
if line.startswith(TMUX_SESSION_PREFIX.encode())
37+
]
38+
return sessions
39+
except subprocess.CalledProcessError:
40+
return []
41+
42+
43+
def start_tmux_session(script_path: str) -> bool:
44+
"""Start a tmux session running the given Python script, returning True on success."""
45+
# Normalize paths for Git Bash / Windows compatibility
46+
python_exec = sys.executable.replace("\\", "/")
47+
script_path = script_path.replace("\\", "/")
48+
session_name = TMUX_SESSION_PREFIX + Path(script_path).name.replace("_hl.py", "")
49+
num_experiments = 5 # always 5 experiments to get rliable evaluations
50+
51+
cmd = [
52+
"tmux",
53+
"new-session",
54+
"-d",
55+
"-s",
56+
session_name,
57+
f"{python_exec} {script_path} --num_experiments {num_experiments}; echo 'Finished {script_path}'; tmux kill-session -t {session_name}",
58+
]
59+
try:
60+
subprocess.run(cmd, check=True)
61+
log.info(
62+
f"Started {script_path} in session '{session_name}'. Attach with:\ntmux attach -t {session_name}"
63+
)
64+
return True
65+
except subprocess.CalledProcessError as e:
66+
log.error(f"Failed to start {script_path} (session {session_name}): {e}")
67+
return False
68+
69+
70+
def main(max_concurrent_sessions: int = 2, benchmark_type: str = "mujoco"):
71+
"""
72+
Run the benchmarking by executing each high level script in its default configuration
73+
(apart from num_experiments, which will be set to 5) in its own tmux session.
74+
Note that if you have unclosed tmux sessions from previous runs, those will count
75+
towards the max_concurrent_sessions limit. You can terminate all sessions with
76+
`tmux kill-server`.
77+
78+
:param max_concurrent_sessions: how many scripts to run in parallel, each script will
79+
run in a tmux session
80+
:param benchmark_type: mujoco or atari
81+
:return:
82+
"""
83+
log_file = Path(__file__).parent / "logs" / f"benchmarking_{datetime_tag()}.txt"
84+
log_file.parent.mkdir(parents=True, exist_ok=True)
85+
logging.add_file_logger(log_file, append=False)
86+
87+
log.info(
88+
f"=== Starting benchmark batch for '{benchmark_type}' with {max_concurrent_sessions} concurrent jobs ==="
89+
)
90+
scripts = find_script_paths(benchmark_type)
91+
log.info(f"Found {len(scripts)} scripts to run.")
92+
93+
for i, script in enumerate(scripts, start=1):
94+
# Wait for free slot
95+
has_printed_waiting = False
96+
while len(get_current_tmux_sessions()) >= max_concurrent_sessions:
97+
if not has_printed_waiting:
98+
log.info(
99+
f"Max concurrent sessions reached ({max_concurrent_sessions}). "
100+
f"Current sessions:\n{get_current_tmux_sessions()}\nWaiting for a free slot..."
101+
)
102+
has_printed_waiting = True
103+
time.sleep(5)
104+
105+
log.info(f"Starting script {i}/{len(scripts)}")
106+
session_started = start_tmux_session(script)
107+
if session_started:
108+
time.sleep(2) # Give tmux a moment to start the session
109+
110+
log.info("All jobs have been started.")
111+
log.info("Use 'tmux ls' to list all active sessions.")
112+
log.info("Use 'tmux attach -t <session_name>' to attach to a running session.")
113+
log.info("===============================================================")
114+
115+
116+
if __name__ == "__main__":
117+
logging.run_cli(main)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ ignore = [
190190
"PLC0415", # local imports
191191
"SIM108", # if else is fine instead of ternary
192192
"PLW1641", # weird thing requiring __hash__ for Protocol
193+
"RET504", # sometimes we want to assign before return (e.g. for debugging)
193194
]
194195
unfixable = [
195196
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all

0 commit comments

Comments
 (0)