Skip to content

Commit 11f8daf

Browse files
committed
reformatted
1 parent c9c7d62 commit 11f8daf

File tree

3 files changed

+41
-28
lines changed

3 files changed

+41
-28
lines changed
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
from .rank_gemini import SafeGenai
22
from .rank_gpt import SafeOpenai
33
from .rank_listwise_os_llm import RankListwiseOSLLM
4+
from .rankk_reranker import RankKReranker
45
from .vicuna_reranker import VicunaReranker
56
from .zephyr_reranker import ZephyrReranker
6-
from .rankk_reranker import RankKReranker
77

88
__all__ = [
99
"RankListwiseOSLLM",
1010
"VicunaReranker",
1111
"ZephyrReranker",
1212
"SafeOpenai",
1313
"SafeGenai",
14-
"RankKReranker"
14+
"RankKReranker",
1515
]
16-
Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,48 @@
11
from typing import Optional
22

3-
from rank_llm.data import Result
3+
from rank_llm.data import Result
44
from rank_llm.rerank import PromptMode
55
from rank_llm.rerank.listwise import RankListwiseOSLLM
66

77

88
class RankKReranker(RankListwiseOSLLM):
99
def __init__(
10-
self,
11-
model: str = "hltcoe/Rank-K-32B",
12-
context_size: int = 4096,
13-
prompt_mode: PromptMode = PromptMode.RANK_GPT,
14-
prompt_template_path: Optional[str] = "src/rank_llm/rerank/prompt_templates/rank_k_template.yaml",
15-
num_few_shot_examples: int = 0,
16-
device: str = "cuda",
17-
num_gpus: int = 1,
18-
variable_passages: bool = True,
19-
window_size: int = 20,
20-
use_alpha: bool = False
10+
self,
11+
model: str = "hltcoe/Rank-K-32B",
12+
context_size: int = 4096,
13+
prompt_mode: PromptMode = PromptMode.RANK_GPT,
14+
prompt_template_path: Optional[
15+
str
16+
] = "src/rank_llm/rerank/prompt_templates/rank_k_template.yaml",
17+
num_few_shot_examples: int = 0,
18+
device: str = "cuda",
19+
num_gpus: int = 1,
20+
variable_passages: bool = True,
21+
window_size: int = 20,
22+
use_alpha: bool = False,
2123
) -> None:
2224
super().__init__(
2325
model=model,
2426
context_size=context_size,
2527
prompt_mode=prompt_mode,
2628
prompt_template_path=prompt_template_path,
27-
num_few_shot_examples = num_few_shot_examples,
28-
device = device,
29+
num_few_shot_examples=num_few_shot_examples,
30+
device=device,
2931
num_gpus=num_gpus,
30-
variable_passages= variable_passages,
32+
variable_passages=variable_passages,
3133
is_thinking=True,
3234
reasoning_token_budget=10000,
3335
window_size=window_size,
34-
use_alpha=use_alpha)
35-
36+
use_alpha=use_alpha,
37+
)
38+
3639
def receive_permutation(
37-
self, result: Result, permutation: str, rank_start: int, rank_end: int, logging: bool = False
40+
self,
41+
result: Result,
42+
permutation: str,
43+
rank_start: int,
44+
rank_end: int,
45+
logging: bool = False,
3846
) -> Result:
3947
"""
4048
Processes and applies a permutation to the ranking results.
@@ -54,13 +62,14 @@ def receive_permutation(
5462
Result: The updated result object with the new ranking order applied.
5563
5664
Note:
57-
This function assumes that the permutation string has reasoning generated by Rank-K-32B preceding
65+
This function assumes that the permutation string has reasoning generated by Rank-K-32B preceding
5866
the sequence of integers separated by spaces.
59-
The function would take the last line of input, and call receive_permutation function from
60-
the superclass.
67+
The function takes the last line of the input and calls the receive_permutation function of
68+
the superclass.
6169
"""
6270

6371
# Remove all the reasoning, and take only the order
6472
permutation = permutation.strip().split("\n")[-1]
65-
return super().receive_permutation(result, permutation, rank_start, rank_end, logging)
66-
73+
return super().receive_permutation(
74+
result, permutation, rank_start, rank_end, logging
75+
)

src/rank_llm/rerank/reranker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
get_genai_api_key,
1010
get_openai_api_key,
1111
)
12-
from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeGenai, SafeOpenai, RankKReranker
12+
from rank_llm.rerank.listwise import (
13+
RankKReranker,
14+
RankListwiseOSLLM,
15+
SafeGenai,
16+
SafeOpenai,
17+
)
1318
from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore
1419
from rank_llm.rerank.pairwise.duot5 import DuoT5
1520
from rank_llm.rerank.pointwise.monot5 import MonoT5
@@ -523,7 +528,7 @@ def create_model_coordinator(
523528
model=(model_path),
524529
context_size=context_size,
525530
prompt_mode=prompt_mode,
526-
prompt_template_path = prompt_template_path,
531+
prompt_template_path=prompt_template_path,
527532
num_few_shot_examples=num_few_shot_examples,
528533
device=device,
529534
num_gpus=num_gpus,

0 commit comments

Comments
 (0)