
Commit 948961b

Parallelize tests for faster PR turnaround time (#372)
* Enhance CI workflow to support sharded testing for core tests. Update test names to include shard information and streamline test execution by replacing inline pytest commands with a dedicated script for sharded tests. Remove coverage upload steps for batch evaluation and MCP end-to-end tests.
1 parent 697be66 commit 948961b

File tree

5 files changed (+98, -120 lines)


.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 0 additions & 60 deletions
This file was deleted.

.github/workflows/ci.yml

Lines changed: 6 additions & 58 deletions
@@ -53,13 +53,15 @@ jobs:
         uv run basedpyright || true

   test-core:
-    name: Core Tests (Python ${{ matrix.python-version }})
+    name: Core Tests (Python ${{ matrix.python-version }}, Shard ${{ matrix.shard }}/${{ matrix.total-shards }})
     runs-on: ubuntu-latest
     needs: lint-and-type-check
     strategy:
       fail-fast: false
       matrix:
         python-version: ["3.10", "3.11", "3.12"]
+        shard: [1, 2, 3, 4]
+        total-shards: [4]

     steps:
       - uses: actions/checkout@v4
@@ -82,7 +84,7 @@ jobs:
       - name: Install tau2 for testing
         run: uv pip install git+https://github.com/sierra-research/tau2-bench.git@main

-      - name: Run Core Tests with pytest-xdist
+      - name: Run Core Tests with pytest-xdist (Shard ${{ matrix.shard }}/${{ matrix.total-shards }})
         env:
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
@@ -94,31 +96,7 @@ jobs:
           SUPABASE_DATABASE: ${{ secrets.SUPABASE_DATABASE }}
           SUPABASE_USER: ${{ secrets.SUPABASE_USER }}
           PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
-        run: |
-          # Run most tests in parallel, but explicitly ignore tests that manage their own servers or are slow
-          uv run pytest \
-            -n auto \
-            --ignore=tests/test_batch_evaluation.py \
-            --ignore=tests/pytest/test_frozen_lake.py \
-            --ignore=tests/pytest/test_lunar_lander.py \
-            --ignore=tests/pytest/test_tau_bench_airline.py \
-            --ignore=tests/pytest/test_apps_coding.py \
-            --ignore=tests/test_tau_bench_airline_smoke.py \
-            --ignore=tests/pytest/test_svgbench.py \
-            --ignore=tests/pytest/test_livesvgbench.py \
-            --ignore=tests/remote_server/test_remote_fireworks.py \
-            --ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \
-            --ignore=tests/logging/test_elasticsearch_direct_http_handler.py \
-            --ignore=eval_protocol/benchmarks/ \
-            --ignore=eval_protocol/quickstart/ \
-            --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10
-
-      - name: Store coverage file
-        uses: actions/upload-artifact@v4
-        with:
-          name: coverage-core-${{ matrix.python-version }}
-          path: coverage.xml
-          retention-days: 1
+        run: uv run ./scripts/run_sharded_tests.sh ${{ matrix.shard }} ${{ matrix.total-shards }}

   test-batch-evaluation:
     name: Batch Evaluation Tests
@@ -153,13 +131,7 @@ jobs:
           PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
         run: |
           # Run only this specific test file, WITHOUT xdist
-          uv run pytest tests/test_batch_evaluation.py --cov=eval_protocol --cov-append --cov-report=xml -v --durations=10
-      - name: Store coverage file
-        uses: actions/upload-artifact@v4
-        with:
-          name: coverage-batch-eval
-          path: coverage.xml
-          retention-days: 1
+          uv run pytest tests/test_batch_evaluation.py -v --durations=10

   test-mcp-e2e:
     name: MCP End-to-End Tests
@@ -183,27 +155,3 @@ jobs:

       - name: Install tau2 for testing
         run: uv pip install git+https://github.com/sierra-research/tau2-bench.git@main
-
-      - name: Store coverage file
-        uses: actions/upload-artifact@v4
-        with:
-          name: coverage-mcp-e2e
-          path: coverage.xml
-          retention-days: 1
-
-  upload-coverage:
-    name: Upload Coverage
-    runs-on: ubuntu-latest
-    needs: [test-core, test-batch-evaluation, test-mcp-e2e]
-    steps:
-      - name: Download all coverage artifacts
-        uses: actions/download-artifact@v4
-        with:
-          path: coverage-artifacts
-      - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v3
-        with:
-          token: ${{ secrets.CODECOV_TOKEN }}
-          directory: ./coverage-artifacts/
-          fail_ci_if_error: false
-          verbose: true
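
Note: the matrix above fans the core tests out into one job per python-version/shard pair. A minimal sketch of that expansion (illustration only, not part of the commit; GitHub Actions builds one job per combination of matrix values):

# Hypothetical illustration of the matrix expansion above.
python_versions = ["3.10", "3.11", "3.12"]
shards = [1, 2, 3, 4]
jobs = [(v, s) for v in python_versions for s in shards]
print(len(jobs))  # 12 core-test jobs run in parallel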

scripts/run_sharded_tests.sh

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
#!/usr/bin/env bash
# Script to run a shard of tests for parallel CI execution
# Usage: ./scripts/run_sharded_tests.sh <shard> <total_shards> [--dry-run]
# Example: ./scripts/run_sharded_tests.sh 1 4
# Example: ./scripts/run_sharded_tests.sh 1 4 --dry-run

set -e

SHARD=${1:-1}
TOTAL_SHARDS=${2:-4}
DRY_RUN=${3:-""}

if [ "$SHARD" -lt 1 ] || [ "$SHARD" -gt "$TOTAL_SHARDS" ]; then
    echo "Error: Shard must be between 1 and $TOTAL_SHARDS"
    exit 1
fi

# Collect all test files, excluding ignored ones
TEST_FILES=$(find tests -name "test_*.py" \
    ! -path "tests/test_batch_evaluation.py" \
    ! -path "tests/pytest/test_frozen_lake.py" \
    ! -path "tests/pytest/test_lunar_lander.py" \
    ! -path "tests/pytest/test_tau_bench_airline.py" \
    ! -path "tests/pytest/test_apps_coding.py" \
    ! -path "tests/test_tau_bench_airline_smoke.py" \
    ! -path "tests/pytest/test_svgbench.py" \
    ! -path "tests/pytest/test_livesvgbench.py" \
    ! -path "tests/remote_server/test_remote_fireworks.py" \
    ! -path "tests/remote_server/test_remote_fireworks_propagate_status.py" \
    ! -path "tests/logging/test_elasticsearch_direct_http_handler.py" \
    | sort)

# Count total files
TOTAL_FILES=$(echo "$TEST_FILES" | wc -l | tr -d ' ')

# Calculate start and end line numbers for this shard (1-indexed for sed)
FILES_PER_SHARD=$(( (TOTAL_FILES + TOTAL_SHARDS - 1) / TOTAL_SHARDS ))
START_LINE=$(( (SHARD - 1) * FILES_PER_SHARD + 1 ))
END_LINE=$(( START_LINE + FILES_PER_SHARD - 1 ))
if [ $END_LINE -gt $TOTAL_FILES ]; then
    END_LINE=$TOTAL_FILES
fi

# Get files for this shard using sed
SHARD_FILES=$(echo "$TEST_FILES" | sed -n "${START_LINE},${END_LINE}p")
SHARD_COUNT=$(echo "$SHARD_FILES" | grep -c . || echo 0)

echo "========================================"
echo "Running shard $SHARD of $TOTAL_SHARDS"
echo "========================================"
echo "Total test files: $TOTAL_FILES"
echo "Files per shard: ~$FILES_PER_SHARD"
echo "Files in this shard: $SHARD_COUNT"
echo "Line range: $START_LINE to $END_LINE"
echo "----------------------------------------"
echo "Files:"
echo "$SHARD_FILES" | while read -r f; do
    echo "  $f"
done
echo "----------------------------------------"

if [ "$SHARD_COUNT" -eq 0 ] || [ -z "$SHARD_FILES" ]; then
    echo "No files in this shard, skipping tests"
    exit 0
fi

# Check if --dry-run flag is passed
if [ "$DRY_RUN" = "--dry-run" ]; then
    echo "Dry run mode - not executing tests"
    exit 0
fi

# Run tests for this shard
# shellcheck disable=SC2086
exec pytest \
    -n auto \
    --ignore=eval_protocol/benchmarks/ \
    --ignore=eval_protocol/quickstart/ \
    -v --durations=10 \
    $SHARD_FILES
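
Note: the shard arithmetic in the script is ceiling division over a sorted file list, with each shard taking a contiguous slice. A rough Python re-statement of that logic (a sketch with a hypothetical shard_slice helper, not used by the commit), with a worked example:

# Mirrors the script's FILES_PER_SHARD / START_LINE / END_LINE math.
def shard_slice(files: list[str], shard: int, total_shards: int) -> list[str]:
    files = sorted(files)
    per_shard = -(-len(files) // total_shards)  # ceiling division
    start = (shard - 1) * per_shard             # 0-indexed slice start
    return files[start : start + per_shard]

# Worked example: 14 files across 4 shards -> sizes 4, 4, 4, 2
files = [f"tests/test_{i:02d}.py" for i in range(14)]
for s in (1, 2, 3, 4):
    print(s, len(shard_slice(files, s, 4)))

Because the slices are contiguous and sized by ceiling division, trailing shards can be smaller, or even empty when there are fewer files than shards; the script handles the empty case by printing a notice and exiting 0.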

tests/pytest/test_pytest_klavis_mcp.py

Lines changed: 6 additions & 1 deletion
@@ -52,7 +52,12 @@ async def test_pytest_klavis_mcp(row: EvaluationRow) -> EvaluationRow:
     )
     response_text = response.choices[0].message.content
     logger.info("response_text: %s", response_text)
-    score = json.loads(response_text or "{}")["score"]
+    try:
+        parsed = json.loads(response_text or "{}")
+        score = parsed.get("score", 0.0)
+    except (json.JSONDecodeError, TypeError):
+        logger.warning("Failed to parse response as JSON: %s", response_text)
+        score = 0.0

     row.evaluation_result = EvaluateResult(
         score=score,
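
Note: the new try/except guards against model output that isn't valid JSON, and parsed.get defaults the score when the key is absent. A small standalone sketch of the same pattern (hypothetical extract_score helper and inputs, just to show the failure modes being handled):

import json
import logging

logger = logging.getLogger(__name__)

def extract_score(response_text: str | None) -> float:
    # Tolerate None, non-JSON text, and a missing "score" key.
    try:
        parsed = json.loads(response_text or "{}")
        return parsed.get("score", 0.0)
    except (json.JSONDecodeError, TypeError):
        logger.warning("Failed to parse response as JSON: %s", response_text)
        return 0.0

print(extract_score('{"score": 0.8}'))   # 0.8
print(extract_score("not json"))         # 0.0 (JSONDecodeError)
print(extract_score(None))               # 0.0 (falls back to "{}")
print(extract_score('{"reason": "x"}'))  # 0.0 (key missing)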

tests/test_cli_create_rft.py

Lines changed: 6 additions & 1 deletion
@@ -1031,7 +1031,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     assert captured["jsonl_path"] == str(jsonl_path)


-def test_cli_full_command_style_evaluator_and_dataset_flags(monkeypatch):
+def test_cli_full_command_style_evaluator_and_dataset_flags(tmp_path, monkeypatch):
+    # Isolate CWD so _discover_tests doesn't run pytest in the real project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
     # Env
     monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
     monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "pyroworks-dev")
