Skip to content

Commit 9c5759c

Browse files
committed
Rank-K-32B Integration
1 parent 33f4983 commit 9c5759c

File tree

4 files changed

+146
-1
lines changed

4 files changed

+146
-1
lines changed

src/rank_llm/rerank/listwise/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .rank_gemini import SafeGenai
22
from .rank_gpt import SafeOpenai
33
from .rank_listwise_os_llm import RankListwiseOSLLM
4+
from .rankk_reranker import RankKReranker
45
from .vicuna_reranker import VicunaReranker
56
from .zephyr_reranker import ZephyrReranker
67

@@ -10,4 +11,5 @@
1011
"ZephyrReranker",
1112
"SafeOpenai",
1213
"SafeGenai",
14+
"RankKReranker",
1315
]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Optional
2+
3+
from rank_llm.data import Result
4+
from rank_llm.rerank import PromptMode
5+
from rank_llm.rerank.listwise import RankListwiseOSLLM
6+
7+
8+
class RankKReranker(RankListwiseOSLLM):
    """Listwise reranker wrapper for the hltcoe/Rank-K-32B reasoning model.

    Thin subclass of ``RankListwiseOSLLM`` that (a) wires in Rank-K-specific
    defaults (thinking mode enabled with a 10k reasoning-token budget and the
    Rank-K prompt template) and (b) strips the model's chain-of-thought from
    its output before the permutation is parsed.
    """

    def __init__(
        self,
        model: str = "hltcoe/Rank-K-32B",
        context_size: int = 4096,
        prompt_mode: PromptMode = PromptMode.RANK_GPT,
        prompt_template_path: Optional[
            str
        ] = "src/rank_llm/rerank/prompt_templates/rank_k_template.yaml",
        num_few_shot_examples: int = 0,
        device: str = "cuda",
        num_gpus: int = 1,
        variable_passages: bool = True,
        window_size: int = 20,
        use_alpha: bool = False,
    ) -> None:
        # Delegate everything to the generic open-source listwise reranker;
        # is_thinking / reasoning_token_budget are fixed because Rank-K-32B
        # always emits reasoning before the final ordering.
        super().__init__(
            model=model,
            context_size=context_size,
            prompt_mode=prompt_mode,
            prompt_template_path=prompt_template_path,
            num_few_shot_examples=num_few_shot_examples,
            device=device,
            num_gpus=num_gpus,
            variable_passages=variable_passages,
            is_thinking=True,
            reasoning_token_budget=10000,
            window_size=window_size,
            use_alpha=use_alpha,
        )

    def receive_permutation(
        self,
        result: Result,
        permutation: str,
        rank_start: int,
        rank_end: int,
        logging: bool = False,
    ) -> Result:
        """Apply a model-produced permutation to a slice of the ranking.

        Rank-K-32B prefixes its answer with free-form reasoning, so only the
        final line of ``permutation`` is expected to contain the ordering
        (integers such as ``[3] > [1] > [2]``). That line is isolated here and
        forwarded to the superclass parser.

        Args:
            result: Result object holding the current candidate ranking.
            permutation: Raw model output — reasoning text followed, on its
                last line, by the ordering string.
            rank_start: Start index (inclusive) of the window being reordered.
            rank_end: End index of the window being reordered.
            logging: Passed through to the superclass implementation.

        Returns:
            The ``result`` object with the window reordered per the model.
        """
        # Keep only the text after the final newline — i.e. drop the
        # reasoning and retain the ordering line. rsplit("\n", 1)[-1] is
        # equivalent to split("\n")[-1] and also handles single-line output.
        ranking_line = permutation.strip().rsplit("\n", 1)[-1]
        return super().receive_permutation(
            result, ranking_line, rank_start, rank_end, logging
        )
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
method: "singleturn_listwise"
2+
prefix: |-
3+
Determine a ranking of the passages based on how relevant they are to the query.
4+
If the query is a question, how relevant a passage is depends on how well it answers the question.
5+
If not, try to analyze the intent of the query and assess how well each passage satisfies the intent.
6+
The query may have typos and passages may contain contradicting information.
7+
However, we do not get into fact-checking. We just rank the passages based on their relevancy to the query.
8+
9+
Sort them from the most relevant to the least.
10+
Answer with the passage number using a format of `[3] > [2] > [4] = [1] > [5]`.
11+
Ties are acceptable if they are equally relevant.
12+
I need you to be accurate but overthinking it is unnecessary.
13+
Output only the ordering without any other text.
14+
15+
Query: {query}
16+
body: "\n\n[{rank}] {candidate}"
17+
suffix: ""

src/rank_llm/rerank/reranker.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
get_genai_api_key,
1010
get_openai_api_key,
1111
)
12-
from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeGenai, SafeOpenai
12+
from rank_llm.rerank.listwise import (
13+
RankKReranker,
14+
RankListwiseOSLLM,
15+
SafeGenai,
16+
SafeOpenai,
17+
)
1318
from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore
1419
from rank_llm.rerank.pairwise.duot5 import DuoT5
1520
from rank_llm.rerank.pointwise.monot5 import MonoT5
@@ -487,6 +492,52 @@ def create_model_coordinator(
487492
elif model_path in ["unspecified", "rank_random", "rank_identity"]:
488493
# NULL reranker
489494
agent = None
495+
elif "hltcoe/Rank-K-32B" in model_path:
496+
print(f"Loading {model_path} ...")
497+
keys_and_defaults = [
498+
("context_size", 4096),
499+
("prompt_mode", PromptMode.RANK_GPT),
500+
(
501+
"prompt_template_path",
502+
"src/rank_llm/rerank/prompt_templates/rank_k_template.yaml",
503+
),
504+
("num_few_shot_examples", 0),
505+
("device", "cuda"),
506+
("num_gpus", 1),
507+
("variable_passages", False),
508+
("window_size", 20),
509+
("system_message", None),
510+
("use_logits", False),
511+
("use_alpha", False),
512+
]
513+
[
514+
context_size,
515+
prompt_mode,
516+
prompt_template_path,
517+
num_few_shot_examples,
518+
device,
519+
num_gpus,
520+
variable_passages,
521+
window_size,
522+
system_message,
523+
use_logits,
524+
use_alpha,
525+
] = extract_kwargs(keys_and_defaults, **kwargs)
526+
527+
model_coordinator = RankKReranker(
528+
model=(model_path),
529+
context_size=context_size,
530+
prompt_mode=prompt_mode,
531+
prompt_template_path=prompt_template_path,
532+
num_few_shot_examples=num_few_shot_examples,
533+
device=device,
534+
num_gpus=num_gpus,
535+
variable_passages=variable_passages,
536+
window_size=window_size,
537+
use_alpha=use_alpha,
538+
)
539+
540+
print(f"Completed loading {model_path}")
490541
else:
491542
# supports loading models from huggingface
492543
print(f"Loading {model_path} ...")

0 commit comments

Comments
 (0)