101 changes: 101 additions & 0 deletions src/rank_llm/rerank/pointwise/bge_reranker_v2.py
@@ -0,0 +1,101 @@
import logging
import math
from typing import List, Tuple
import torch

from FlagEmbedding import LayerWiseFlagLLMReranker, FlagLLMReranker, FlagReranker

from rank_llm.data import Result
from rank_llm.rerank.pointwise.pointwise_rankllm import PointwiseRankLLM

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class BGE_RERANKER_V2(PointwiseRankLLM):
def __init__(
self,
model: str,
prompt_mode: str = "bge-reranker-v2",
context_size: int = 8192,
device: str = "cuda",
batch_size: int = 32,
use_fp16: bool = False
):
super().__init__(
model=model,
context_size=context_size,
prompt_mode=prompt_mode,
device=device,
batch_size=batch_size,
)

if "base" in self._model or "large" in self._model or "m3" in self._model:
self._llm = FlagReranker(self._model, use_fp16=use_fp16)
elif "minicpm-layerwise" in self._model:
self._llm = LayerWiseFlagLLMReranker(self._model, use_fp16=use_fp16)
elif "gemma" in self._model:
self._llm = FlagLLMReranker(self._model, use_fp16=use_fp16)
else:
raise ValueError("Given bge model doesn't exist or isn't supported in rank_llm.")

def run_llm_batched(
self,
prompts: List[List[str]],
) -> Tuple[None, None, List[float]]:

all_outputs = None
all_output_token_counts = None
all_scores = []

pairs = prompts

with torch.no_grad():
if "base" in self._model or "large" in self._model or "m3" in self._model:
all_scores = self._llm.compute_score(pairs)

elif "gemma" in self._model:
all_scores = self._llm.compute_score(pairs)

elif "minicpm-layerwise" in self._model:
scores = self._llm.compute_score(pairs, cutoff_layers=[28])
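# compute_score may return one list of scores per requested cutoff layer
# (assumption about FlagEmbedding's layerwise reranker); flatten the layer-28 scores below.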
if not isinstance(scores[0], float):
for score in scores[0]:
all_scores.append(score)
else:
all_scores = scores

return all_outputs, all_output_token_counts, all_scores

def run_llm(self, prompt: str) -> Tuple[None, None, float]:

outputs = None
output_ids = None

pair = prompt

with torch.no_grad():
if "base" in self._model or "large" in self._model or "m3" in self._model:
score = self._llm.compute_score(pair)

elif "gemma" in self._model:
score = self._llm.compute_score(pair)

elif "minicpm-layerwise" in self._model:
score = self._llm.compute_score(pair, cutoff_layers=[28])

return outputs, output_ids, score

def num_output_tokens(self) -> int:
return 1

def create_prompt(self, result: Result, index: int) -> Tuple[List[str], None]:
query = result.query.text
prompt = [query, self.convert_doc_to_prompt_content(result.candidates[index].doc, max_length=self._context_size)]

return prompt, None

def get_num_tokens(self, prompt: str) -> int:
return 1

def cost_per_1k_token(self, input_token: bool) -> float:
return 0
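
For reference, a minimal smoke test of the new agent might look like the sketch below. The class and constructor arguments mirror this diff; the checkpoint name and query/passage pair are illustrative, and calling run_llm directly on a [query, passage] pair (the same shape create_prompt produces) is an assumption about how the base PointwiseRankLLM is used.

from rank_llm.rerank.pointwise.bge_reranker_v2 import BGE_RERANKER_V2

# Sketch only: FlagEmbedding downloads the checkpoint on first use,
# and device="cuda" assumes a GPU is available.
reranker = BGE_RERANKER_V2(
    model="BAAI/bge-reranker-v2-m3",
    device="cuda",
    use_fp16=True,
)

# run_llm scores a single [query, passage] pair, as built by create_prompt.
_, _, score = reranker.run_llm(
    ["what is a panda?", "The giant panda is a bear species endemic to China."]
)
print(f"relevance score: {score:.4f}")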
63 changes: 55 additions & 8 deletions src/rank_llm/rerank/pointwise/monot5.py
@@ -1,15 +1,38 @@
import logging
import math
from typing import List, Tuple

import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer, T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig

from rank_llm.data import Result
from rank_llm.rerank.pointwise.pointwise_rankllm import PointwiseRankLLM

logger = logging.getLogger(__name__)

PREDICTION_TOKENS = {
"castorini/monot5-base-msmarco": ["▁false", "▁true"],
"castorini/monot5-base-msmarco-10k": ["▁false", "▁true"],
"castorini/monot5-large-msmarco": ["▁false", "▁true"],
"castorini/monot5-large-msmarco-10k": ["▁false", "▁true"],
"castorini/monot5-base-med-msmarco": ["▁false", "▁true"],
"castorini/monot5-3b-med-msmarco": ["▁false", "▁true"],
"castorini/monot5-3b-msmarco-10k": ["▁false", "▁true"],
"unicamp-dl/mt5-base-en-msmarco": ["▁no", "▁yes"],
"unicamp-dl/ptt5-base-pt-msmarco-10k-v2": ["▁não", "▁sim"],
"unicamp-dl/ptt5-base-pt-msmarco-100k-v2": ["▁não", "▁sim"],
"unicamp-dl/ptt5-base-en-pt-msmarco-100k-v2": ["▁não", "▁sim"],
"unicamp-dl/mt5-base-en-pt-msmarco-v2": ["▁no", "▁yes"],
"unicamp-dl/mt5-base-mmarco-v2": ["▁no", "▁yes"],
"unicamp-dl/mt5-base-en-pt-msmarco-v1": ["▁no", "▁yes"],
"unicamp-dl/mt5-base-mmarco-v1": ["▁no", "▁yes"],
"unicamp-dl/ptt5-base-pt-msmarco-10k-v1": ["▁não", "▁sim"],
"unicamp-dl/ptt5-base-pt-msmarco-100k-v1": ["▁não", "▁sim"],
"unicamp-dl/ptt5-base-en-pt-msmarco-10k-v1": ["▁não", "▁sim"],
"unicamp-dl/mt5-3b-mmarco-en-pt": ["▁", "▁true"],
"unicamp-dl/mt5-13b-mmarco-100k": ["▁", "▁true"],
}


class MonoT5(PointwiseRankLLM):
def __init__(
@@ -19,6 +42,7 @@ def __init__(
context_size: int = 512,
device: str = "cuda",
batch_size: int = 32,
dtype: str = "float16"
):
super().__init__(
model=model,
@@ -28,10 +52,26 @@ def __init__(
batch_size=batch_size,
)

if "mt5" in model:
self._tokenizer = MT5Tokenizer.from_pretrained(model)
torch_dtype = torch.float16 if dtype == "float16" else torch.float32
self._llm = MT5ForConditionalGeneration.from_pretrained(model, torch_dtype=torch_dtype).to(self._device)
elif "monot5" in model:
self._tokenizer = T5Tokenizer.from_pretrained(model)
self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device)
else:
raise ValueError(f"Unsupported checkpoint for the MonoT5 reranker: {model}")
self._context_size = context_size

@staticmethod
def get_prediction_tokens(pretrained_model_name_or_path: str, tokenizer):
if pretrained_model_name_or_path in PREDICTION_TOKENS:
token_false, token_true = PREDICTION_TOKENS[pretrained_model_name_or_path]
token_false_id = tokenizer.get_vocab()[token_false]
token_true_id = tokenizer.get_vocab()[token_true]
return token_false_id, token_true_id
else:
raise Exception(f"We don't know the indexes for the non-relevant/relevant tokens for\
the checkpoint {pretrained_model_name_or_path} and you did not provide any.")
pass
def run_llm_batched(
self,
prompts: List[str],
@@ -70,10 +110,15 @@ def run_llm_batched(
]

for logit_tensor in batch_logits[0]:
token_false_id, token_true_id = self.get_prediction_tokens(self._model, self._tokenizer)
truth_logit = logit_tensor[token_true_id]
false_logit = logit_tensor[token_false_id]

score = math.exp(truth_logit) / (
math.exp(truth_logit) + math.exp(false_logit)
)
all_scores.append(score)
all_output_token_counts.append(self.num_output_tokens())
@@ -82,6 +127,8 @@

return all_outputs, all_output_token_counts, all_scores

def run_llm(self, prompt: str) -> Tuple[str, int, float]:
gen_cfg = GenerationConfig.from_model_config(self._llm.config)
gen_cfg.max_new_tokens = self.num_output_tokens()
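
For reference, the scoring that PREDICTION_TOKENS and get_prediction_tokens enable is a two-way softmax over the logits of the checkpoint's relevant/non-relevant tokens. A self-contained sketch of that arithmetic (the logit values are made up for illustration):

import math

# Hypothetical first-token logits for a monoT5-style checkpoint.
truth_logit = 2.3   # logit of the "relevant" token, e.g. "▁true" or "▁yes"
false_logit = -1.1  # logit of the "non-relevant" token, e.g. "▁false" or "▁no"

# Relevance score = softmax restricted to the two prediction tokens.
score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit))
print(round(score, 4))  # 0.9677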
1 change: 1 addition & 0 deletions src/rank_llm/rerank/rankllm.py
@@ -15,6 +15,7 @@ class PromptMode(Enum):
LRL = "LRL"
MONOT5 = "monot5"
LiT5 = "LiT5"
BGE_RERANKER_V2 = "bge-reranker-v2"

def __str__(self):
return self.value
55 changes: 55 additions & 0 deletions src/rank_llm/rerank/reranker.py
@@ -11,6 +11,7 @@
from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeOpenai
from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore
from rank_llm.rerank.pointwise.monot5 import MonoT5
from rank_llm.rerank.pointwise.bge_reranker_v2 import BGE_RERANKER_V2
from rank_llm.rerank.rankllm import RankLLM


@@ -279,6 +280,30 @@ def create_agent(
batch_size=batch_size,
)

elif "mt5" in model_path:
# mT5 checkpoints are also handled by the MonoT5 agent
print(f"Loading {model_path} ...")

keys_and_defaults = [
("prompt_mode", PromptMode.MONOT5),
("context_size", 512),
("device", "cuda"),
("batch_size", 64),
("dtype", "float32"),
]
[prompt_mode, context_size, device, batch_size, dtype] = extract_kwargs(
keys_and_defaults, **kwargs
)

agent = MonoT5(
model=model_path,
prompt_mode=prompt_mode,
context_size=context_size,
device=device,
batch_size=batch_size,
dtype=dtype
)

elif "lit5-distill" in model_path.lower():
keys_and_defaults = [
("context_size", 150),
@@ -345,6 +370,36 @@ def create_agent(
batched=vllm_batched,
)
print(f"Completed loading {model_path}")

elif "bge-reranker" in model_path.lower():
print(f"Loading {model_path} ...")

keys_and_defaults = [
("device", "cuda"),
("use_fp16", True),
("prompt_mode", PromptMode.BGE_RERANKER_V2),
("context_size", 8192),
("batch_size", 64)
]
(
device,
use_fp16,
prompt_mode,
context_size,
batch_size
) = extract_kwargs(keys_and_defaults, **kwargs)

agent = BGE_RERANKER_V2(
model=model_path,
prompt_mode=prompt_mode,
context_size=context_size,
use_fp16=use_fp16,
batch_size=batch_size,
device=device
)

print(f"Completed loading {model_path}")

elif model_path in ["unspecified", "rank_random", "rank_identity"]:
# NULL reranker
agent = None
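
Since create_agent dispatches on substrings of model_path, a quick sanity check of which branch a few checkpoints would hit (paths illustrative; the existing monoT5 branch is assumed to match on the "monot5" substring):

# Sketch only: mirrors the substring checks added in this file.
for path in [
    "unicamp-dl/mt5-base-mmarco-v2",   # "mt5" in path -> new mT5/MonoT5 branch
    "castorini/monot5-base-msmarco",   # no "mt5" substring -> existing monoT5 branch (assumed)
    "BAAI/bge-reranker-v2-gemma",      # "bge-reranker" in lowered path -> BGE_RERANKER_V2 branch
]:
    print(path, "mt5" in path, "bge-reranker" in path.lower())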
8 changes: 8 additions & 0 deletions src/rank_llm/scripts/run_rank_llm.py
@@ -38,6 +38,7 @@ def main(args):
window_size = args.window_size
system_message = args.system_message
vllm_batched = args.vllm_batched
dtype = args.dtype

_ = retrieve_and_rerank(
model_path=model_path,
@@ -62,6 +63,7 @@ def main(args):
step_size=step_size,
system_message=system_message,
vllm_batched=vllm_batched,
dtype=dtype
)


@@ -175,5 +177,11 @@ def main(args):
action="store_true",
help="whether to run the model in batches",
)
parser.add_argument(
"--dtype",
default="float16",
type=str,
help="determine which dtype to do inference with, either float16 or float32",
)
args = parser.parse_args()
main(args)
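
For reference, the new flag slots in next to the existing CLI arguments; a hedged invocation (the other flag names are assumed to match the current script, and the dataset/retrieval values are illustrative):

python src/rank_llm/scripts/run_rank_llm.py \
    --model_path=unicamp-dl/mt5-base-mmarco-v2 \
    --dataset=dl19 \
    --retrieval_method=bm25 \
    --prompt_mode=monot5 \
    --context_size=512 \
    --dtype=float16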