diff --git a/benchmarks/commit0/run_infer.py b/benchmarks/commit0/run_infer.py index 90e964911..2c7d74a13 100644 --- a/benchmarks/commit0/run_infer.py +++ b/benchmarks/commit0/run_infer.py @@ -1,7 +1,6 @@ import json import os from collections import Counter -from pathlib import Path from typing import Any, List from commit0.harness.constants import SPLIT @@ -13,7 +12,7 @@ get_base_docker_image, ) from benchmarks.commit0.config import INFER_DEFAULTS -from benchmarks.utils.args_parser import get_parser +from benchmarks.utils.args_parser import add_prompt_path_argument, get_parser from benchmarks.utils.console_logging import summarize_instance from benchmarks.utils.constants import EVAL_AGENT_SERVER_IMAGE from benchmarks.utils.conversation import build_event_persistence_callback @@ -593,21 +592,8 @@ def evaluate_instance( def main() -> None: - prompt_dir = (Path(__file__).parent / "prompts").resolve() - choices = [str(p.relative_to(Path.cwd())) for p in prompt_dir.glob("*.j2")] - default_prompt_path = prompt_dir / "default.j2" - assert default_prompt_path.exists(), ( - f"Default prompt {default_prompt_path} not found" - ) - parser = get_parser() - parser.add_argument( - "--prompt-path", - type=str, - default=str(default_prompt_path), - choices=choices, - help="Path to prompt template file", - ) + add_prompt_path_argument(parser, __file__) parser.add_argument( "--repo-split", type=str, diff --git a/benchmarks/multiswebench/run_infer.py b/benchmarks/multiswebench/run_infer.py index 833fcdcc4..e29af8276 100644 --- a/benchmarks/multiswebench/run_infer.py +++ b/benchmarks/multiswebench/run_infer.py @@ -1,6 +1,5 @@ import json import os -from pathlib import Path from typing import List, cast import pandas as pd @@ -13,7 +12,7 @@ ) from benchmarks.multiswebench.download_dataset import download_and_concat_dataset from benchmarks.multiswebench.scripts.data.data_change import format_data_for_inference -from benchmarks.utils.args_parser import get_parser +from 
benchmarks.utils.args_parser import add_prompt_path_argument, get_parser from benchmarks.utils.build_utils import ensure_local_image from benchmarks.utils.console_logging import summarize_instance from benchmarks.utils.constants import EVAL_AGENT_SERVER_IMAGE @@ -404,21 +403,8 @@ def evaluate_instance( def main() -> None: - prompt_dir = (Path(__file__).parent / "prompts").resolve() - choices = [str(p.relative_to(Path.cwd())) for p in prompt_dir.glob("*.j2")] - default_prompt_path = prompt_dir / "default.j2" - assert default_prompt_path.exists(), ( - f"Default prompt {default_prompt_path} not found" - ) - parser = get_parser() - parser.add_argument( - "--prompt-path", - type=str, - default=str(default_prompt_path), - choices=choices, - help="Path to prompt template file", - ) + add_prompt_path_argument(parser, __file__) parser.add_argument( "--lang", type=str, diff --git a/benchmarks/swebench/run_infer.py b/benchmarks/swebench/run_infer.py index 5021db16e..03bc6c78d 100644 --- a/benchmarks/swebench/run_infer.py +++ b/benchmarks/swebench/run_infer.py @@ -1,6 +1,5 @@ import json import os -from pathlib import Path from typing import List from jinja2 import Environment, FileSystemLoader @@ -13,7 +12,7 @@ wrap_image, ) from benchmarks.swebench.config import INFER_DEFAULTS -from benchmarks.utils.args_parser import get_parser +from benchmarks.utils.args_parser import add_prompt_path_argument, get_parser from benchmarks.utils.build_utils import ensure_local_image from benchmarks.utils.console_logging import summarize_instance from benchmarks.utils.constants import EVAL_AGENT_SERVER_IMAGE @@ -349,21 +348,8 @@ def evaluate_instance( def main() -> None: - prompt_dir = (Path(__file__).parent / "prompts").resolve() - choices = [str(p.relative_to(Path.cwd())) for p in prompt_dir.glob("*.j2")] - default_prompt_path = prompt_dir / "default.j2" - assert default_prompt_path.exists(), ( - f"Default prompt {default_prompt_path} not found" - ) - parser = get_parser() - 
parser.add_argument( - "--prompt-path", - type=str, - default=str(default_prompt_path), - choices=choices, - help="Path to prompt template file", - ) + add_prompt_path_argument(parser, __file__) parser.set_defaults(**INFER_DEFAULTS) args = parser.parse_args() diff --git a/benchmarks/swebenchmultimodal/run_infer.py b/benchmarks/swebenchmultimodal/run_infer.py index fdf5382ad..ca65f8ca9 100644 --- a/benchmarks/swebenchmultimodal/run_infer.py +++ b/benchmarks/swebenchmultimodal/run_infer.py @@ -1,6 +1,5 @@ import json import os -from pathlib import Path from typing import List import requests @@ -11,7 +10,7 @@ get_official_docker_image, ) from benchmarks.swebenchmultimodal.config import INFER_DEFAULTS -from benchmarks.utils.args_parser import get_parser +from benchmarks.utils.args_parser import add_prompt_path_argument, get_parser from benchmarks.utils.build_utils import ensure_local_image from benchmarks.utils.console_logging import summarize_instance from benchmarks.utils.constants import EVAL_AGENT_SERVER_IMAGE @@ -402,21 +401,8 @@ def evaluate_instance( def main() -> None: - prompt_dir = (Path(__file__).parent / "prompts").resolve() - choices = [str(p.relative_to(Path.cwd())) for p in prompt_dir.glob("*.j2")] - default_prompt_path = prompt_dir / "default.j2" - assert default_prompt_path.exists(), ( - f"Default prompt {default_prompt_path} not found" - ) - parser = get_parser() - parser.add_argument( - "--prompt-path", - type=str, - default=str(default_prompt_path), - choices=choices, - help="Path to prompt template file", - ) + add_prompt_path_argument(parser, __file__) # Apply INFER_DEFAULTS from config (matches evaluation repository values.yaml) parser.set_defaults(**INFER_DEFAULTS) args = parser.parse_args() diff --git a/benchmarks/swefficiency/run_infer.py b/benchmarks/swefficiency/run_infer.py index f559d7f6e..64bd7bb22 100644 --- a/benchmarks/swefficiency/run_infer.py +++ b/benchmarks/swefficiency/run_infer.py @@ -1,7 +1,6 @@ import json import 
multiprocessing import os -from pathlib import Path from typing import Any, List from jinja2 import Environment, FileSystemLoader @@ -10,7 +9,7 @@ from benchmarks.swefficiency import constants from benchmarks.swefficiency.config import DOCKER_DEFAULTS, INFER_DEFAULTS from benchmarks.swefficiency.workspace import ResourceLimitedDockerWorkspace -from benchmarks.utils.args_parser import get_parser +from benchmarks.utils.args_parser import add_prompt_path_argument, get_parser from benchmarks.utils.build_utils import ensure_local_image from benchmarks.utils.conversation import build_event_persistence_callback from benchmarks.utils.critics import create_critic @@ -395,21 +394,8 @@ def evaluate_instance( def main() -> None: - prompt_dir = (Path(__file__).parent / "prompts").resolve() - choices = [str(p.relative_to(Path.cwd())) for p in prompt_dir.glob("*.j2")] - default_prompt_path = prompt_dir / "default.j2" - assert default_prompt_path.exists(), ( - f"Default prompt {default_prompt_path} not found" - ) - parser = get_parser() - parser.add_argument( - "--prompt-path", - type=str, - default=str(default_prompt_path), - choices=choices, - help="Path to prompt template file", - ) + add_prompt_path_argument(parser, __file__) parser.add_argument( "--num-cpus-per-worker", type=int, diff --git a/benchmarks/swtbench/run_infer.py b/benchmarks/swtbench/run_infer.py index 132949427..8d01c0208 100644 --- a/benchmarks/swtbench/run_infer.py +++ b/benchmarks/swtbench/run_infer.py @@ -1,12 +1,11 @@ import json import os -from pathlib import Path from typing import List from jinja2 import Environment, FileSystemLoader from benchmarks.swtbench.config import INFER_DEFAULTS -from benchmarks.utils.args_parser import get_parser +from benchmarks.utils.args_parser import add_prompt_path_argument, get_parser from benchmarks.utils.console_logging import summarize_instance from benchmarks.utils.constants import EVAL_AGENT_SERVER_IMAGE from benchmarks.utils.conversation import 
def add_prompt_path_argument(parser: argparse.ArgumentParser, caller_file: str) -> None:
    """Add a ``--prompt-path`` argument backed by the benchmark's ``prompts/`` dir.

    Prompt templates are located relative to the caller's directory rather
    than the current working directory, so the argument works regardless of
    where the process is launched.

    The value given on the command line is resolved in this order:

    1. If it contains a directory component (absolute or relative path) and
       points at an existing file, that exact file is used. An explicit path
       is never silently swapped for a bundled template that merely shares
       its basename.
    2. Otherwise its basename is looked up in the benchmark's ``prompts/``
       directory. This covers bare filenames such as ``default.j2`` and
       legacy repo-relative paths passed from a different working directory.
    3. Otherwise a bare filename that exists relative to the current working
       directory is accepted.

    The parsed value is always an absolute path string.

    Args:
        parser: The argument parser to add the argument to.
        caller_file: Pass ``__file__`` from the calling module so its sibling
            ``prompts/`` directory can be located.

    Raises:
        AssertionError: If ``prompts/default.j2`` does not exist next to
            ``caller_file``.
    """
    prompt_dir = (Path(caller_file).parent / "prompts").resolve()
    default_template = prompt_dir / "default.j2"
    templates = sorted(p.name for p in prompt_dir.glob("*.j2"))

    # Raised explicitly (instead of using an ``assert`` statement) so the
    # check still runs under ``python -O`` while callers keep seeing the same
    # exception type.
    if not default_template.is_file():
        raise AssertionError(f"Default prompt {default_template} not found")

    def _resolve_prompt(value: str) -> str:
        """Resolve a filename or path to an absolute prompt template path."""
        given = Path(value)
        # ``name != value`` means the user typed a directory component
        # ("./x.j2", "dir/x.j2", "/abs/x.j2") and therefore chose a location.
        is_explicit_path = given.name != value

        # An explicit, existing path always wins over a same-named bundled
        # template.
        if is_explicit_path and given.is_file():
            return str(given.resolve())

        # Bare filenames (and stale explicit paths) fall back to the bundled
        # templates by basename.
        bundled = prompt_dir / given.name
        if bundled.is_file():
            return str(bundled)

        # Last resort: a bare filename that exists relative to the CWD.
        if given.is_file():
            return str(given.resolve())

        raise argparse.ArgumentTypeError(
            f"Prompt template not found: {value!r}. Available: {', '.join(templates)}"
        )

    parser.add_argument(
        "--prompt-path",
        type=_resolve_prompt,
        default=str(default_template),
        metavar="{" + ",".join(templates) + "}",
        help="Prompt template filename (default: default.j2)",
    )
+ caller = tmp_path / "run_infer.py" + caller.write_text("") + return caller + + +class TestAddPromptPathArgument: + """Test that --prompt-path resolves correctly in all scenarios.""" + + def test_default_value_is_absolute(self, prompt_tree: Path) -> None: + parser = argparse.ArgumentParser() + add_prompt_path_argument(parser, str(prompt_tree)) + args = parser.parse_args([]) + assert Path(args.prompt_path).is_absolute() + assert args.prompt_path.endswith("default.j2") + + def test_bare_filename_resolves_to_absolute(self, prompt_tree: Path) -> None: + parser = argparse.ArgumentParser() + add_prompt_path_argument(parser, str(prompt_tree)) + args = parser.parse_args(["--prompt-path", "custom.j2"]) + assert Path(args.prompt_path).is_absolute() + assert args.prompt_path.endswith("custom.j2") + + def test_absolute_path_accepted(self, prompt_tree: Path) -> None: + abs_path = str(prompt_tree.parent / "prompts" / "custom.j2") + parser = argparse.ArgumentParser() + add_prompt_path_argument(parser, str(prompt_tree)) + args = parser.parse_args(["--prompt-path", abs_path]) + assert args.prompt_path == abs_path + + def test_works_from_different_cwd( + self, prompt_tree: Path, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """The original bug: ValueError when CWD != project root.""" + # Change CWD to an unrelated directory. 
+ other_dir = tmp_path / "elsewhere" + other_dir.mkdir() + monkeypatch.chdir(other_dir) + + parser = argparse.ArgumentParser() + add_prompt_path_argument(parser, str(prompt_tree)) + args = parser.parse_args([]) + assert Path(args.prompt_path).is_file() + + def test_invalid_template_gives_clear_error(self, prompt_tree: Path) -> None: + parser = argparse.ArgumentParser() + add_prompt_path_argument(parser, str(prompt_tree)) + with pytest.raises(SystemExit): + parser.parse_args(["--prompt-path", "nonexistent.j2"]) + + def test_missing_default_template_raises(self, tmp_path: Path) -> None: + prompts_dir = tmp_path / "prompts" + prompts_dir.mkdir() + (prompts_dir / "other.j2").write_text("no default") + caller = tmp_path / "run_infer.py" + caller.write_text("") + + parser = argparse.ArgumentParser() + with pytest.raises(AssertionError, match="default.j2"): + add_prompt_path_argument(parser, str(caller)) + + def test_relative_path_accepted( + self, prompt_tree: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Backwards compatibility: relative paths that exist are resolved.""" + monkeypatch.chdir(prompt_tree.parent) + parser = argparse.ArgumentParser() + add_prompt_path_argument(parser, str(prompt_tree)) + args = parser.parse_args(["--prompt-path", "prompts/custom.j2"]) + assert Path(args.prompt_path).is_absolute() + assert args.prompt_path.endswith("custom.j2")