Skip to content

Commit baf3b39

Browse files
authored
Pairwise inference handler implementation (#243)
- Added a pairwise inference handler and wired it into the RankLLM and DuoT5 classes. - Fixed an existing bug in truncating doc1 and doc2 based on the remaining context size.
1 parent 813fe2b commit baf3b39

File tree

6 files changed

+220
-54
lines changed

6 files changed

+220
-54
lines changed

src/rank_llm/rerank/pairwise/duot5.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import math
3+
import re
34
from typing import List, Optional, Tuple
45

56
from transformers import T5ForConditionalGeneration, T5Tokenizer
@@ -16,7 +17,7 @@ def __init__(
1617
self,
1718
model: str,
1819
prompt_mode: str = "duot5",
19-
prompt_template_path: Optional[str] = None,
20+
prompt_template_path: str = "src/rank_llm/rerank/prompt_templates/duot5_template.yaml",
2021
context_size: int = 512,
2122
num_few_shot_examples: int = 0,
2223
few_shot_file: Optional[str] = None,
@@ -101,7 +102,7 @@ def run_llm(self, prompt: str) -> Tuple[str, int, float]:
101102
def create_prompt(
102103
self, result: Result, index1: int, index2: int
103104
) -> Tuple[str, int]:
104-
query = self._replace_number(result.query.text)
105+
query = re.sub(r"\[(\d+)\]", r"(\1)", result.query.text)
105106

106107
reserved_for_output = (
107108
64 # might need to change depending on what the actual output look like
@@ -117,29 +118,14 @@ def create_prompt(
117118
self._context_size - reserved_for_output - query_tokens - few_shot_tokens
118119
)
119120

120-
doc1_raw = self.convert_doc_to_prompt_content(
121-
result.candidates[index1].doc, max_length=max_token
121+
# TODO (issue #237): need to modify the class to be able to add fewshot examples later
122+
prompt = self._inference_handler.generate_prompt(
123+
result=result,
124+
index1=index1,
125+
index2=index2,
126+
max_token=max_token,
127+
tokenizer=self._tokenizer,
122128
)
123-
doc2_raw = self.convert_doc_to_prompt_content(
124-
result.candidates[index2].doc, max_length=max_token
125-
)
126-
127-
doc1_tokens = self._tokenizer.encode(
128-
doc1_raw, truncation=True, max_length=max_token
129-
)
130-
doc2_tokens = self._tokenizer.encode(
131-
doc2_raw, truncation=True, max_length=max_token
132-
)
133-
134-
doc1 = self._tokenizer.decode(doc1_tokens, skip_special_tokens=True)
135-
doc2 = self._tokenizer.decode(doc2_tokens, skip_special_tokens=True)
136-
137-
prompt = (
138-
few_shot_prompt
139-
+ f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant: "
140-
)
141-
prompt = prompt.replace("<unk>", "")
142-
143129
return prompt, self.get_num_tokens(prompt)
144130

145131
def get_num_tokens(self, prompt: str) -> int:
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
from typing import Any, Dict

from transformers import T5Tokenizer

from rank_llm.data import Result
from rank_llm.rerank.inference_handler import BaseInferenceHandler


class PairwiseInferenceHandler(BaseInferenceHandler):
    """Inference handler that renders pairwise (duo-style) prompts.

    A pairwise prompt compares exactly two candidate documents of a
    :class:`Result` against the query, using the ``body`` section of the
    template with the ``{query}``, ``{doc1}`` and ``{doc2}`` placeholders.
    """

    def __init__(self, template: Dict[str, str]):
        super().__init__(template)

    def _validate_template(self, template: Dict[str, str], strict: bool = False):
        """Validate that *template* describes a pairwise prompt.

        Args:
            template: Parsed template mapping (must contain ``method`` and ``body``).
            strict: Forwarded to ``_general_validation``; presumably rejects
                unknown keys/placeholders when True — confirm against the base class.

        Raises:
            ValueError: If ``template["method"]`` is not ``"pairwise"`` or the
                required sections/placeholders are missing.
        """
        TEMPLATE_SECTIONS = {
            # Format:
            # "template_key": {
            #     "required": True/False,          # Whether the section itself is mandatory
            #     "required_placeholders": set(),  # Placeholders that must exist in this section
            #     "allowed_placeholders": set(),   # Optional placeholders beyond the required ones
            # }
            "body": {
                "required": True,
                "required_placeholders": {"query", "doc1", "doc2"},
                "allowed_placeholders": set(),
            },
        }

        # Validate the method value before anything else; a mismatched method
        # means the caller loaded the wrong template file.
        if template["method"] != "pairwise":
            raise ValueError(
                f'Incorrect method type, expected "pairwise", got {template["method"]}'
            )

        self._general_validation(
            template=template, template_section=TEMPLATE_SECTIONS, strict=strict
        )

    # TODO (issue #273): May need to add prefix/suffix generation function later

    def _generate_body(
        self,
        result: Result,
        index1: int,
        index2: int,
        single_doc_max_token: int,
        tokenizer: T5Tokenizer,
    ) -> str:
        """Render the template body for candidates *index1* and *index2*.

        Each document is truncated independently to *single_doc_max_token*
        tokens (encode with truncation, then decode back to text) so that one
        long document cannot consume the other document's share of the
        context window.
        """
        doc1_raw = self._convert_doc_to_prompt_content(
            result.candidates[index1].doc, max_length=single_doc_max_token
        )
        doc2_raw = self._convert_doc_to_prompt_content(
            result.candidates[index2].doc, max_length=single_doc_max_token
        )

        doc1_tokens = tokenizer.encode(
            doc1_raw, truncation=True, max_length=single_doc_max_token
        )
        doc2_tokens = tokenizer.encode(
            doc2_raw, truncation=True, max_length=single_doc_max_token
        )

        # _replace_number rewrites "[n]" citations so they are not confused
        # with ranking identifiers downstream.
        query = self._replace_number(result.query.text)
        doc1 = tokenizer.decode(doc1_tokens, skip_special_tokens=True)
        doc2 = tokenizer.decode(doc2_tokens, skip_special_tokens=True)

        fmt_values = {"query": query, "doc1": doc1, "doc2": doc2}
        body_text = self._format_template(template_key="body", fmt_values=fmt_values)

        return body_text

    def generate_prompt(self, result: Result, **kwargs: Any) -> str:
        """Build the full pairwise prompt for one document pair.

        Required keyword arguments:
            index1, index2: Candidate indices to compare.
            max_token: Token budget shared by BOTH documents; each document
                gets half of it.
            tokenizer: Tokenizer used for truncation.

        Raises:
            ValueError: If any required keyword argument is missing.
        """
        try:
            index1 = kwargs["index1"]
            index2 = kwargs["index2"]
            max_token = kwargs["max_token"]
            tokenizer = kwargs["tokenizer"]
        except KeyError as e:
            # Chain the KeyError so the missing-parameter cause is preserved.
            raise ValueError(f"Missing required parameter: {e}") from e

        # Split the remaining context budget evenly between the two documents.
        single_doc_max_token = max_token // 2

        prompt = self._generate_body(
            result=result,
            index1=index1,
            index2=index2,
            single_doc_max_token=single_doc_max_token,
            tokenizer=tokenizer,
        )
        # T5 tokenizers can emit "<unk>" for unmappable characters; strip it
        # so the literal token never appears in the prompt text.
        return prompt.replace("<unk>", "")

src/rank_llm/rerank/pairwise/pairwise_rankllm.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from abc import ABC
66
from datetime import datetime
77
from functools import cmp_to_key
8-
from typing import Any, Dict, List, Optional, Tuple
8+
from typing import Any, List, Optional, Tuple
99

10-
from ftfy import fix_text
1110
from tqdm import tqdm
1211

1312
from rank_llm.data import Candidate, Request, Result
@@ -172,32 +171,6 @@ def candidate_comparator(self, x: Candidate, y: Candidate) -> int:
172171
else:
173172
return 0
174173

175-
def _replace_number(self, s: str) -> str:
176-
return re.sub(r"\[(\d+)\]", r"(\1)", s)
177-
178-
def convert_doc_to_prompt_content(
179-
self, doc: Dict[str, Any], max_length: int
180-
) -> str:
181-
if "text" in doc:
182-
content = doc["text"]
183-
elif "segment" in doc:
184-
content = doc["segment"]
185-
elif "contents" in doc:
186-
content = doc["contents"]
187-
elif "content" in doc:
188-
content = doc["content"]
189-
elif "body" in doc:
190-
content = doc["body"]
191-
else:
192-
content = doc["passage"]
193-
if "title" in doc and doc["title"]:
194-
content = "Title: " + doc["title"] + " " + "Content: " + content
195-
content = content.strip()
196-
content = fix_text(content)
197-
# For Japanese should cut by character: content = content[:int(max_length)]
198-
content = " ".join(content.split()[: int(max_length)])
199-
return self._replace_number(content)
200-
201174
def _build_pairwise_few_shot_examples(self) -> str:
202175
if self._num_few_shot_examples > 0 and hasattr(self, "_examples"):
203176
examples = []
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
# Pairwise (duo-style) prompt template: one query compared against two documents.
method: "pairwise"
body: "Query: {query} Document0: {doc1} Document1: {doc2} Relevant: "

src/rank_llm/rerank/rankllm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151

5252
if bool(data):
5353
self._inference_handler = self._create_handler(data)
54+
print(f"Successfully created {data['method']} inference handler!")
5455

5556
if self._num_few_shot_examples > 0:
5657
if not few_shot_file:
@@ -211,6 +212,9 @@ def _create_handler(self, template: Dict[str, str]) -> BaseInferenceHandler:
211212
from rank_llm.rerank.listwise.singleturn_listwise_inference_handler import (
212213
SingleTurnListwiseInferenceHandler,
213214
)
215+
from rank_llm.rerank.pairwise.pairwise_inference_handler import (
216+
PairwiseInferenceHandler,
217+
)
214218
from rank_llm.rerank.pointwise.pointwise_inference_handler import (
215219
PointwiseInferenceHandler,
216220
)
@@ -222,8 +226,10 @@ def _create_handler(self, template: Dict[str, str]) -> BaseInferenceHandler:
222226
return MultiTurnListwiseInferenceHandler(template)
223227
elif template["method"] == "pointwise":
224228
return PointwiseInferenceHandler(template)
225-
else: # TODO(issue #236 and #237): Need to remove this after all the handlers are implemented
226-
return SingleTurnListwiseInferenceHandler(template)
229+
elif template["method"] == "pairwise":
230+
return PairwiseInferenceHandler(template)
231+
else:
232+
raise ValueError("Invalid template method")
227233
except:
228234
raise ValueError("Please provide a method section in the template")
229235

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
import unittest

from dacite import from_dict
from transformers import T5Tokenizer

from rank_llm.data import Result
from rank_llm.rerank.pairwise.pairwise_inference_handler import PairwiseInferenceHandler


def _candidate(contents: str, docid: str, score: float) -> dict:
    """Build one raw candidate entry for the Result fixture."""
    return {"doc": {"contents": contents}, "docid": docid, "score": score}


# Shared Result fixture with four candidates of descending score.
r = from_dict(
    data_class=Result,
    data={
        "query": {"text": "Sample Query", "qid": "q1"},
        "candidates": [
            _candidate("Title1: Sample Title1 Content1: Sample Text1", "d1", 0.5),
            _candidate("Title2: Sample Title2 Content2: Sample Text2", "d2", 0.4),
            _candidate("Title3: Sample Title3 Content3: Sample Text3", "d3", 0.4),
            _candidate("Title4: Sample Title4 Content4: Sample Text4", "d4", 0.3),
        ],
    },
)


VALID_PAIRWISE_TEMPLATE = {
    "method": "pairwise",
    "body": "Query: {query} Document0: {doc1} Document1: {doc2}",
}
INVALID_PAIRWISE_TEMPLATES = [
    {
        "method": "singleturn_listwise",
        "body": "{query} {doc1} {doc2}",
    },  # Wrong method type
    {
        "method": "pairwise",
        "body": "{query} {doc1}",
    },  # Missing required placeholder: {doc2}
    {
        "method": "pairwise",
        "body": "{query} {doc1} {doc2}",
        "unknown_key": "value",
    },  # Unknown key
]
tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-3b-msmarco-10k")


class TestPairwiseInferenceHandler(unittest.TestCase):
    def test_pairwise_valid_template_initialization(self):
        handler = PairwiseInferenceHandler(VALID_PAIRWISE_TEMPLATE)
        self.assertEqual(handler.template, VALID_PAIRWISE_TEMPLATE)

    def test_invalid_templates(self):
        # Every malformed template must be rejected at construction time.
        for template in INVALID_PAIRWISE_TEMPLATES:
            with self.subTest(template=template):
                with self.assertRaises(ValueError):
                    PairwiseInferenceHandler(template)

    def test_body_generation(self):
        handler = PairwiseInferenceHandler(VALID_PAIRWISE_TEMPLATE)
        cases = [
            (
                1,
                "Query: Sample Query Document0: Title1: Sample Title1 Content1: Sample Text1 Document1: Title2: Sample Title2 Content2: Sample Text2",
            ),
            (
                2,
                "Query: Sample Query Document0: Title1: Sample Title1 Content1: Sample Text1 Document1: Title3: Sample Title3 Content3: Sample Text3",
            ),
        ]
        for index2, expected in cases:
            with self.subTest(index2=index2):
                body_text = handler._generate_body(
                    result=r,
                    index1=0,
                    index2=index2,
                    single_doc_max_token=6000,
                    tokenizer=tokenizer,
                )
                self.assertEqual(body_text, expected)

    def test_prompt_generation(self):
        handler = PairwiseInferenceHandler(VALID_PAIRWISE_TEMPLATE)
        cases = [
            (
                1,
                "Query: Sample Query Document0: Title1: Sample Title1 Content1: Sample Text1 Document1: Title2: Sample Title2 Content2: Sample Text2",
            ),
            (
                2,
                "Query: Sample Query Document0: Title1: Sample Title1 Content1: Sample Text1 Document1: Title3: Sample Title3 Content3: Sample Text3",
            ),
        ]
        for index2, expected in cases:
            with self.subTest(index2=index2):
                prompt_text = handler.generate_prompt(
                    result=r, index1=0, index2=index2, max_token=6000, tokenizer=tokenizer
                )
                self.assertEqual(prompt_text, expected)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)