1 change: 1 addition & 0 deletions pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "python-dotenv>=1.1.1",
     "qwen-agent[gui,mcp,rag]==0.0.27",
     "qwen-omni-utils==0.0.8",
+    "rank-llm>=0.21.0",
     "rich>=14.0.0",
     "tevatron",
     "torchvision",
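(Presumably, the new rank-llm dependency backs the rank-llm-based rerankers registered under RerankerType later in this PR.)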
4 changes: 2 additions & 2 deletions scripts_evaluation/evaluate_run.py
@@ -447,9 +447,9 @@ def main():
     output_dir = mirror_directory_structure(input_dir, eval_dir)
     print(f"Evaluations will be saved to {output_dir}")
 
-    json_files = list(input_dir.glob("*.json"))
+    json_files = list(input_dir.glob("run_*.json"))
     if not json_files:
-        print(f"No JSON files found in {input_dir}")
+        print(f"No JSON files starting with 'run_' found in {input_dir}")
         return
 
     print(f"Found {len(json_files)} JSON files to evaluate")
6 changes: 4 additions & 2 deletions scripts_evaluation/evaluate_with_openai.py
@@ -11,6 +11,7 @@
 
 import numpy as np
 import openai
+from dotenv import load_dotenv
 from tqdm import tqdm
 
 sys.path.append(str(Path(__file__).parent.parent))
@@ -431,16 +432,17 @@ def main():
     output_dir = mirror_directory_structure(input_dir, eval_dir)
     print(f"Evaluations will be saved to {output_dir}")
 
-    json_files = list(input_dir.glob("*.json"))
+    json_files = list(input_dir.glob("run_*.json"))
     if not json_files:
-        print(f"No JSON files found in {input_dir}")
+        print(f"No JSON files starting with 'run_' found in {input_dir}")
         return
 
     print(f"Found {len(json_files)} JSON files to evaluate")
 
     all_results = []
     skipped = 0
 
+    load_dotenv()
     api_key = os.getenv("OPENAI_API_KEY")
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY is not set in environment")
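With load_dotenv() in place, OPENAI_API_KEY can be supplied via a local .env file instead of the shell environment; a minimal sketch (placeholder value, assuming a .env at the repository root):

    OPENAI_API_KEY=sk-your-key-here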
Empty file.
142 changes: 142 additions & 0 deletions scripts_retrieval_only/retrieve.py
@@ -0,0 +1,142 @@
import argparse
import csv
import os
import sys
from pathlib import Path
from typing import List, Tuple

from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from searcher.rerankers import RerankerType
from searcher.searchers import BaseSearcher, SearcherType


def load_queries(tsv_path: str) -> List[Tuple[str, str]]:
    """Loads queries from a TSV file."""
    if not tsv_path.strip().lower().endswith(".tsv"):
        raise ValueError(f"Invalid query file format, expected .tsv: {tsv_path}")
    dataset_path = Path(tsv_path)
    if not dataset_path.is_file():
        raise FileNotFoundError(f"TSV query file not found: {tsv_path}")
    queries = []
    with dataset_path.open(newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        for row in reader:
            if len(row) < 2:
                continue
            queries.append((row[0].strip(), row[1].strip()))
    return queries


def process_queries(
    searcher: BaseSearcher, queries: List[Tuple[str, str]], args: argparse.Namespace
):
    """
    Processes queries in batches using retrieve_batch (dict[qid -> candidates]),
    optionally reranks per batch, and streams results to file to avoid OOM.
    """
    batch_size = args.batch_size

    if getattr(args, "reranker_type", None):
        model_name = (
            "bm25" if args.searcher_type == "bm25" else args.model_name.split("/")[-1]
        )
        filename = (
            f"retrieve_{model_name}_k_{args.k}_"
            f"rerank_{args.reranker_model.split('/')[-1]}_k_{args.first_stage_k}.trec"
        )
    else:
        filename = f"retrieve_{args.searcher_type}_k_{args.k}.trec"

    os.makedirs(args.output_dir, exist_ok=True)
    filepath = os.path.join(args.output_dir, filename)
    reranker = None
    if getattr(args, "reranker_type", None):
        reranker_class = RerankerType.get_reranker_class(args.reranker_type)
        reranker = reranker_class(args)

    print(
        f"Retrieving{' + reranking' if reranker else ''} in batches of {batch_size}..."
    )
    with open(filepath, "w", encoding="utf-8") as f:
        for start in tqdm(range(0, len(queries), batch_size)):
            end = min(start + batch_size, len(queries))
            batch = queries[start:end]
            batch_qids = [qid for qid, _ in batch]
            batch_qtexts = [qtext for _, qtext in batch]
            retrieved = searcher.retrieve_batch(batch_qtexts, batch_qids, args.k)
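            # retrieved maps each qid to a ranked list of candidate dicts with
            # at least "docid" and "score" keys (see the write loop below)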
            if reranker is not None:
                Path(f"{args.output_dir}/invocation_history").mkdir(
                    parents=True, exist_ok=True
                )
                history_file_name = f"{args.output_dir}/invocation_history/{filename[:-5]}_{start}_{end}.json"
                batch_queries_dict = {qid: qtext for qid, qtext in batch}
                rerank_results = reranker.rerank_batch(
                    batch_queries_dict, retrieved, history_file_name, args.first_stage_k
                )
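                # rerank_batch is assumed to return one reranked candidate list
                # per query, aligned with the order of batch_qids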
                retrieved = {}
                for i, qid in enumerate(batch_qids):
                    retrieved[qid] = rerank_results[i]
            for qid, _ in batch:
                candidates = retrieved.get(qid, [])
                for rank, cand in enumerate(candidates, start=1):
                    f.write(
                        f'{qid} Q0 {cand["docid"]} {rank} {cand["score"]} {args.searcher_type}\n'
                    )
            del retrieved

    print(f"Done. Wrote results to: {filepath}")


def main():
    """Main function to handle argument parsing and query processing."""
    parser = argparse.ArgumentParser(description="Retrieval with optional reranking")
    parser.add_argument(
        "--query-file",
        default="topics-qrels/queries.tsv",
        help="The .tsv file containing queries.",
    )
    parser.add_argument(
        "--output-dir", required=True, help="Directory to store retrieved results."
    )
    parser.add_argument(
        "--searcher-type",
        choices=SearcherType.get_choices(),
        required=True,
        help=f"Type of searcher to use: {', '.join(SearcherType.get_choices())}.",
    )
    parser.add_argument(
        "--reranker-type",
        choices=RerankerType.get_choices(),
        default=None,
        help=f"Type of reranker to use: None, {', '.join(RerankerType.get_choices())}. Default: None.",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Fixed number of search results to return for all queries (default: 5).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Batch size used for retrieval and, optionally, reranking (default: 64).",
    )
    temp_args, _ = parser.parse_known_args()
    searcher_class = SearcherType.get_searcher_class(temp_args.searcher_type)
    searcher_class.parse_args(parser)
    if temp_args.reranker_type:
        reranker_class = RerankerType.get_reranker_class(temp_args.reranker_type)
        reranker_class.parse_args(parser)
    args = parser.parse_args()
    # process_queries reranks per batch after retrieval, so don't pass a reranker here.
    searcher = searcher_class(reranker=None, args=args)
    queries = load_queries(args.query_file)
    process_queries(searcher, queries, args)


if __name__ == "__main__":
    main()
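A minimal invocation sketch (the searcher name bm25 is an assumption; the real choices come from SearcherType.get_choices(), and each searcher may add flags via its own parse_args):

    python scripts_retrieval_only/retrieve.py \
        --searcher-type bm25 \
        --output-dir runs \
        --k 100

Each line of the resulting .trec file is in standard TREC run format, e.g. q1 Q0 doc42 1 12.5 bm25 (query id, the literal Q0, docid, rank, score, run tag).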
51 changes: 35 additions & 16 deletions search_agent/openai_client.py
@@ -8,6 +8,7 @@
 from pathlib import Path
 
 import openai
+from dotenv import load_dotenv
 from prompts import format_query
 from rich import print as rprint
 from tqdm import tqdm
@@ -17,6 +18,7 @@
 from transformers import AutoTokenizer
 from utils import extract_retrieved_docids_from_result
 
+from searcher.rerankers import RerankerType
 from searcher.searchers import SearcherType
 
 
@@ -83,14 +85,14 @@ def get_tool_definitions(self):
 
     def execute_tool(self, tool_name: str, arguments: dict):
         if tool_name == "search":
-            return self._search(arguments["query"])
+            return self._search(arguments["query"], arguments["query_id"])
         elif tool_name == "get_document":
             return self._get_document(arguments["docid"])
         else:
             raise ValueError(f"Unknown tool: {tool_name}")
 
-    def _search(self, query: str):
-        candidates = self.searcher.search(query, self.k)
+    def _search(self, query: str, query_id: str | None):
+        candidates = self.searcher.search(query, query_id, self.k)
 
         if self.snippet_max_tokens and self.snippet_max_tokens > 0 and self.tokenizer:
             for cand in candidates:
@@ -119,7 +121,6 @@ def _search(self, query: str):
                     "snippet": cand["snippet"],
                 }
             )
-
         return json.dumps(results, indent=2)
 
     def _get_document(self, docid: str):
@@ -172,9 +173,9 @@ def run_conversation_with_tools(
     client: openai.OpenAI,
     initial_request: dict,
     tool_handler: SearchToolHandler,
+    query_id: str | None = None,
     max_iterations: int = 100,
 ):
-
     input_messages = initial_request["input"].copy()
     global_max_tokens = initial_request["max_output_tokens"]
 
@@ -231,7 +232,6 @@
             for item in response.output
             if getattr(item, "type", None) == "function_call"
         ]
-
         if not function_calls:
             return response, combined_output, cumulative_usage, tool_outputs
 
@@ -243,6 +243,7 @@
         for tool_call in function_calls:
             try:
                 args = json.loads(tool_call.arguments)
+                args["query_id"] = query_id
                 result = tool_handler.execute_tool(tool_call.name, args)
 
                 tool_outputs[tool_call.id] = {
@@ -447,10 +448,13 @@ def _handle_single_query(qid: str, qtext: str, pbar=None):
     )
 
     try:
-        response, combined_output, cumulative_usage, tool_outputs = (
-            run_conversation_with_tools(
-                client, request_body, tool_handler, args.max_iterations
-            )
+        (
+            response,
+            combined_output,
+            cumulative_usage,
+            tool_outputs,
+        ) = run_conversation_with_tools(
+            client, request_body, tool_handler, qid, args.max_iterations
         )
 
         if response.status == "completed":
@@ -568,6 +572,12 @@ def main():
         required=True,
         help=f"Type of searcher to use: {', '.join(SearcherType.get_choices())}",
     )
+    parser.add_argument(
+        "--reranker-type",
+        choices=RerankerType.get_choices(),
+        default=None,
+        help=f"Type of reranker to use: None, {', '.join(RerankerType.get_choices())}",
+    )
 
     # Server configuration arguments
     parser.add_argument(
@@ -599,6 +609,12 @@
     )
 
     temp_args, _ = parser.parse_known_args()
+    reranker = None
+    if temp_args.reranker_type:
+        reranker_class = RerankerType.get_reranker_class(temp_args.reranker_type)
+        reranker_class.parse_args(parser)
+        rerank_args, _ = parser.parse_known_args()
+        reranker = reranker_class(rerank_args)
     searcher_class = SearcherType.get_searcher_class(temp_args.searcher_type)
     searcher_class.parse_args(parser)
 
@@ -613,13 +629,13 @@ def main():
         print(f"[DEBUG] Setting HF home from CLI argument: {args.hf_home}")
         os.environ["HF_HOME"] = args.hf_home
 
+    load_dotenv()
     api_key = os.getenv("OPENAI_API_KEY")
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY is not set in environment")
 
     client = openai.OpenAI(api_key=api_key)
 
-    searcher = searcher_class(args)
+    searcher = searcher_class(reranker, args)
 
     tool_handler = SearchToolHandler(
         searcher=searcher,
@@ -664,10 +680,13 @@
     )
 
     print("Sending request to OpenAI Responses API with function calling...")
-    response, combined_output, cumulative_usage, tool_outputs = (
-        run_conversation_with_tools(
-            client, request_body, tool_handler, args.max_iterations
-        )
+    (
+        response,
+        combined_output,
+        cumulative_usage,
+        tool_outputs,
+    ) = run_conversation_with_tools(
+        client, request_body, tool_handler, None, args.max_iterations
     )
 
     _persist_response(
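For orientation, a minimal sketch of the searcher interface these call sites assume (signatures inferred from this diff, not copied from searcher/searchers.py):

    # Hypothetical sketch inferred from call sites in this PR.
    class BaseSearcher:
        def __init__(self, reranker=None, args=None):
            # openai_client.py now constructs searchers as searcher_class(reranker, args)
            self.reranker = reranker
            self.args = args

        def search(self, query: str, query_id: str | None, k: int) -> list[dict]:
            # Candidates carry at least "docid", "score", and "snippet"
            raise NotImplementedError

        def retrieve_batch(
            self, queries: list[str], qids: list[str], k: int
        ) -> dict[str, list[dict]]:
            # Maps each qid to its ranked candidate list, as consumed by retrieve.py
            raise NotImplementedError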