Skip to content

Commit 33f4983

Browse files
authored
Add regression test script and max_queries arg (#246)
- added scripts for regression tests; the tests run with a selected number of queries (max_queries) on a mix of datasets and retrievers.
1 parent baf3b39 commit 33f4983

File tree

4 files changed

+76
-6
lines changed

4 files changed

+76
-6
lines changed

regression_test.sh

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
#!/bin/bash
#
# Regression test harness: runs a fixed set of reranking commands and checks
# that the reported nDCG@10 stays within a relative tolerance of a known-good
# score for each test case.
#
# Test cases are defined as three parallel arrays (index i describes one case):
#   TEST_NAMES=("Test 1" "Test 2" ...)
#   TEST_COMMANDS=("command1" "command2" ...)
#   TEST_EXPECTED_SCORES=(0.123 0.456 ...)
#
# Exits with the number of failed tests (0 = all passed) so CI can detect
# regressions; previously the script always exited 0 even on failure.

TEST_NAMES=(
  "FirstMistral (Alpha, Logits)"
  "RZ"
  "Qwen (Alpha)"
  "Monot5"
  "Duot5"
)

TEST_COMMANDS=(
  "python src/rank_llm/scripts/run_rank_llm.py --model_path=castorini/first_mistral --top_k_candidates=50 --dataset=dl19 --retrieval_method=bm25 --prompt_mode=rank_GPT --context_size=4096 --use_alpha --use_logits --max_queries=3"
  "python src/rank_llm/scripts/run_rank_llm.py --model_path=castorini/rank_zephyr_7b_v1_full --top_k_candidates=50 --dataset=dl20 --retrieval_method=SPLADE++_EnsembleDistil_ONNX --prompt_mode=rank_GPT --context_size=4096 --max_queries=3"
  "python src/rank_llm/scripts/run_rank_llm.py --model_path=Qwen/Qwen2.5-7B-Instruct --top_k_candidates=50 --dataset=dl21 --retrieval_method=bm25 --prompt_mode=rank_GPT --context_size=4096 --variable_passages --max_queries=3"
  "python src/rank_llm/scripts/run_rank_llm.py --model_path=castorini/monot5-3b-msmarco-10k --top_k_candidates=50 --dataset=dl22 --retrieval_method=bm25 --prompt_mode=rank_GPT --context_size=4096 --variable_passages --max_queries=3"
  "python src/rank_llm/scripts/run_rank_llm.py --model_path=castorini/duot5-3b-msmarco-10k --top_k_candidates=50 --dataset=dl23 --retrieval_method=bm25 --prompt_mode=rank_GPT --context_size=4096 --variable_passages --max_queries=1"
)

TEST_EXPECTED_SCORES=(
  0.8085
  0.7662
  0.7157
  0.3997
  0.7246
)

# Accept scores within +/-2.5% of the expected value.
TOLERANCE_LOW=0.975
TOLERANCE_HIGH=1.025

FAILURES=0

for i in "${!TEST_NAMES[@]}"; do
  NAME="${TEST_NAMES[$i]}"
  COMMAND="${TEST_COMMANDS[$i]}"
  EXPECTED_SCORE="${TEST_EXPECTED_SCORES[$i]}"

  echo "Running $NAME..."

  OUTPUT=$(eval "$COMMAND" 2>&1)

  # Extract the nDCG@10 value from the trec_eval-style summary line
  # ("ndcg_cut_10   all   0.1234"). Requires GNU grep (-P).
  SCORE=$(echo "$OUTPUT" | grep -oP 'ndcg_cut_10\s+all\s+\K\d+\.\d+')

  if [ -z "$SCORE" ]; then
    echo "❌ ERROR: Could not extract nDCG@10 score for '$NAME'"
    # Show the tail of the command output so the failure can be diagnosed.
    echo "--- last lines of command output ---"
    echo "$OUTPUT" | tail -n 5
    FAILURES=$((FAILURES + 1))
    continue
  fi

  # bc handles the floating-point comparison; bash arithmetic is integer-only.
  LOWER_BOUND=$(echo "$EXPECTED_SCORE * $TOLERANCE_LOW" | bc -l)
  UPPER_BOUND=$(echo "$EXPECTED_SCORE * $TOLERANCE_HIGH" | bc -l)
  PASSED=$(echo "$SCORE >= $LOWER_BOUND && $SCORE <= $UPPER_BOUND" | bc -l)

  if [ "$PASSED" -eq 1 ]; then
    echo "$NAME: PASS ✅ (Actual Score: $SCORE, Expected Score: $EXPECTED_SCORE)"
  else
    echo "$NAME: FAIL ❌ (Actual Score: $SCORE, Expected Score: $EXPECTED_SCORE)"
    FAILURES=$((FAILURES + 1))
  fi
done

exit "$FAILURES"

src/rank_llm/rerank/reranker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ def create_model_coordinator(
342342
("prompt_mode", PromptMode.MONOT5),
343343
(
344344
"prompt_template_path",
345-
None,
346-
), # TODO(issue #236): Need to modify and add default MONOT5 template
345+
"src/rank_llm/rerank/prompt_templates/monot5_template.yaml",
346+
),
347347
("context_size", 512),
348348
("num_few_shot_examples", 0),
349349
("few_shot_file", None),
@@ -375,7 +375,7 @@ def create_model_coordinator(
375375
batch_size=batch_size,
376376
)
377377
elif "duot5" in model_path:
378-
# using monot5
378+
# using duot5
379379
print(f"Loading {model_path} ...")
380380

381381
model_full_paths = {"duot5": "castorini/duot5-3b-msmarco-10k"}
@@ -384,8 +384,8 @@ def create_model_coordinator(
384384
("prompt_mode", PromptMode.DUOT5),
385385
(
386386
"prompt_template_path",
387-
None,
388-
), # TODO(issue #236): Need to modify and add default DUOT5 template
387+
"src/rank_llm/rerank/prompt_templates/duot5_template.yaml",
388+
),
389389
("context_size", 512),
390390
("device", "cuda"),
391391
("batch_size", 64),

src/rank_llm/retrieve_and_rerank.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from typing import Any, Dict, List, Union
2+
from typing import Any, Dict, List, Optional, Union
33

44
from rank_llm.data import Query, Request
55
from rank_llm.rerank import IdentityReranker, RankLLM, Reranker
@@ -21,6 +21,7 @@ def retrieve_and_rerank(
2121
retrieval_method: RetrievalMethod = RetrievalMethod.BM25,
2222
top_k_retrieve: int = 50,
2323
top_k_rerank: int = 10,
24+
max_queries: Optional[int] = None,
2425
shuffle_candidates: bool = False,
2526
print_prompts_responses: bool = False,
2627
qid: int = 1,
@@ -58,6 +59,9 @@ def retrieve_and_rerank(
5859
**kwargs,
5960
)
6061

62+
if max_queries is not None:
63+
requests = requests[: min(len(requests), max_queries)]
64+
6165
for request in requests:
6266
request.candidates = request.candidates[:top_k_retrieve]
6367

src/rank_llm/scripts/run_rank_llm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def main(args):
2222
context_size = args.context_size
2323
top_k_candidates = args.top_k_candidates
2424
top_k_rerank = top_k_candidates if args.top_k_rerank == -1 else args.top_k_rerank
25+
max_queries = args.max_queries
2526
dataset = args.dataset
2627
num_gpus = args.num_gpus
2728
retrieval_method = args.retrieval_method
@@ -56,6 +57,7 @@ def main(args):
5657
retrieval_method=retrieval_method,
5758
top_k_retrieve=top_k_candidates,
5859
top_k_rerank=top_k_rerank,
60+
max_queries=max_queries,
5961
context_size=context_size,
6062
device=device,
6163
num_gpus=num_gpus,
@@ -119,6 +121,12 @@ def main(args):
119121
default=-1,
120122
help="the number of top candidates to return from reranking",
121123
)
124+
parser.add_argument(
125+
"--max_queries",
126+
type=int,
127+
default=None,
128+
help="the max number of queries to process from the dataset",
129+
)
122130
parser.add_argument(
123131
"--dataset",
124132
type=str,

0 commit comments

Comments
 (0)