diff --git a/src/rank_llm/rerank/pointwise/bge_reranker_v2.py b/src/rank_llm/rerank/pointwise/bge_reranker_v2.py
new file mode 100644
index 00000000..3ba5fb50
--- /dev/null
+++ b/src/rank_llm/rerank/pointwise/bge_reranker_v2.py
@@ -0,0 +1,101 @@
+import logging
+import math
+from typing import List, Tuple
+
+import torch
+from FlagEmbedding import FlagLLMReranker, FlagReranker, LayerWiseFlagLLMReranker
+
+from rank_llm.data import Result
+from rank_llm.rerank.pointwise.pointwise_rankllm import PointwiseRankLLM
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+class BGE_RERANKER_V2(PointwiseRankLLM):
+    def __init__(
+        self,
+        model: str,
+        prompt_mode: str = "bge-reranker-v2",
+        context_size: int = 8192,
+        device: str = "cuda",
+        batch_size: int = 32,
+        use_fp16: bool = False,
+    ):
+        super().__init__(
+            model=model,
+            context_size=context_size,
+            prompt_mode=prompt_mode,
+            device=device,
+            batch_size=batch_size,
+        )
+
+        # Pick the FlagEmbedding wrapper that matches the checkpoint family.
+        if "base" in self._model or "large" in self._model or "m3" in self._model:
+            self._llm = FlagReranker(self._model, use_fp16=use_fp16)
+        elif "minicpm-layerwise" in self._model:
+            self._llm = LayerWiseFlagLLMReranker(self._model, use_fp16=use_fp16)
+        elif "gemma" in self._model:
+            self._llm = FlagLLMReranker(self._model, use_fp16=use_fp16)
+        else:
+            raise ValueError(
+                "The given bge model doesn't exist or isn't supported in rank_llm."
+            )
+
+    def run_llm_batched(
+        self,
+        prompts: List[List[str]],
+    ) -> Tuple[None, None, List[float]]:
+        all_outputs = None
+        all_output_token_counts = None
+        all_scores = []
+
+        pairs = prompts
+
+        with torch.no_grad():
+            if "minicpm-layerwise" in self._model:
+                scores = self._llm.compute_score(pairs, cutoff_layers=[28])
+                # The layerwise reranker may return one list of scores per cutoff layer.
+                if not isinstance(scores[0], float):
+                    all_scores = list(scores[0])
+                else:
+                    all_scores = scores
+            else:
+                # FlagReranker and FlagLLMReranker share the same compute_score signature.
+                all_scores = self._llm.compute_score(pairs)
+
+        return all_outputs, all_output_token_counts, all_scores
+
+    def run_llm(self, prompt: str) -> Tuple[None, None, float]:
+        outputs = None
+        output_ids = None
+
+        pair = prompt
+
+        with torch.no_grad():
+            if "minicpm-layerwise" in self._model:
+                score = self._llm.compute_score(pair, cutoff_layers=[28])
+            else:
+                score = self._llm.compute_score(pair)
+
+        return outputs, output_ids, score
+
+    def num_output_tokens(self) -> int:
+        return 1
+
+    def create_prompt(self, result: Result, index: int) -> Tuple[List[str], None]:
+        query = result.query.text
+        prompt = [
+            query,
+            self.convert_doc_to_prompt_content(
+                result.candidates[index].doc, max_length=self._context_size
+            ),
+        ]
+
+        return prompt, None
+
+    def get_num_tokens(self, prompt: str) -> int:
+        return 1
+
+    def cost_per_1k_token(self, input_token: bool) -> float:
+        return 0
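A quick standalone sanity check of the [query, passage] pair format that create_prompt builds and run_llm_batched feeds to FlagEmbedding; a minimal sketch, not part of the patch, and the checkpoint name and example pairs are illustrative assumptions:

from FlagEmbedding import FlagReranker

# Assumed checkpoint; any bge-reranker model containing "base", "large" or "m3"
# is routed to FlagReranker by BGE_RERANKER_V2.__init__.
reranker = FlagReranker("BAAI/bge-reranker-v2-m3", use_fp16=True)

pairs = [
    ["what is a lobster roll", "A lobster roll is a sandwich filled with lobster meat."],
    ["what is a lobster roll", "Cheese is a dairy product made from curdled milk."],
]
scores = reranker.compute_score(pairs)  # one relevance score per [query, passage] pair
print(scores)  # the first pair should score noticeably higher than the second
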
diff --git a/src/rank_llm/rerank/pointwise/monot5.py b/src/rank_llm/rerank/pointwise/monot5.py
index 6c1c5556..4619aa23 100644
--- a/src/rank_llm/rerank/pointwise/monot5.py
+++ b/src/rank_llm/rerank/pointwise/monot5.py
@@ -1,8 +1,8 @@
 import logging
 import math
 from typing import List, Tuple
-
-from transformers import T5ForConditionalGeneration, T5Tokenizer
+import torch
+from transformers import MT5ForConditionalGeneration, MT5Tokenizer, T5ForConditionalGeneration, T5Tokenizer
 from transformers.generation import GenerationConfig
 
 from rank_llm.data import Result
@@ -10,6 +10,29 @@
 
 logger = logging.getLogger(__name__)
 
+PREDICTION_TOKENS = {
+    'castorini/monot5-base-msmarco': ['▁false', '▁true'],
+    'castorini/monot5-base-msmarco-10k': ['▁false', '▁true'],
+    'castorini/monot5-large-msmarco': ['▁false', '▁true'],
+    'castorini/monot5-large-msmarco-10k': ['▁false', '▁true'],
+    'castorini/monot5-base-med-msmarco': ['▁false', '▁true'],
+    'castorini/monot5-3b-med-msmarco': ['▁false', '▁true'],
+    'castorini/monot5-3b-msmarco-10k': ['▁false', '▁true'],
+    'unicamp-dl/mt5-base-en-msmarco': ['▁no', '▁yes'],
+    'unicamp-dl/ptt5-base-pt-msmarco-10k-v2': ['▁não', '▁sim'],
+    'unicamp-dl/ptt5-base-pt-msmarco-100k-v2': ['▁não', '▁sim'],
+    'unicamp-dl/ptt5-base-en-pt-msmarco-100k-v2': ['▁não', '▁sim'],
+    'unicamp-dl/mt5-base-en-pt-msmarco-v2': ['▁no', '▁yes'],
+    'unicamp-dl/mt5-base-mmarco-v2': ['▁no', '▁yes'],
+    'unicamp-dl/mt5-base-en-pt-msmarco-v1': ['▁no', '▁yes'],
+    'unicamp-dl/mt5-base-mmarco-v1': ['▁no', '▁yes'],
+    'unicamp-dl/ptt5-base-pt-msmarco-10k-v1': ['▁não', '▁sim'],
+    'unicamp-dl/ptt5-base-pt-msmarco-100k-v1': ['▁não', '▁sim'],
+    'unicamp-dl/ptt5-base-en-pt-msmarco-10k-v1': ['▁não', '▁sim'],
+    'unicamp-dl/mt5-3b-mmarco-en-pt': ['▁', '▁true'],
+    'unicamp-dl/mt5-13b-mmarco-100k': ['▁', '▁true'],
+}
+
 
 class MonoT5(PointwiseRankLLM):
     def __init__(
@@ -19,6 +42,7 @@ def __init__(
         context_size: int = 512,
         device: str = "cuda",
         batch_size: int = 32,
+        dtype: str = "float16",
     ):
         super().__init__(
             model=model,
@@ -28,10 +52,26 @@ def __init__(
             batch_size=batch_size,
         )
 
-        self._tokenizer = T5Tokenizer.from_pretrained(model)
-        self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device)
+        torch_dtype = torch.float16 if dtype == "float16" else torch.float32
+        if "mt5" in model:
+            self._tokenizer = MT5Tokenizer.from_pretrained(model)
+            self._llm = MT5ForConditionalGeneration.from_pretrained(
+                model, torch_dtype=torch_dtype
+            ).to(self._device)
+        else:
+            # Default to the T5 classes, as before, for monoT5-style checkpoints.
+            self._tokenizer = T5Tokenizer.from_pretrained(model)
+            self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device)
         self._context_size = context_size
 
+    @staticmethod
+    def get_prediction_tokens(pretrained_model_name_or_path: str, tokenizer):
+        if pretrained_model_name_or_path in PREDICTION_TOKENS:
+            token_false, token_true = PREDICTION_TOKENS[pretrained_model_name_or_path]
+            token_false_id = tokenizer.get_vocab()[token_false]
+            token_true_id = tokenizer.get_vocab()[token_true]
+            return token_false_id, token_true_id
+        else:
+            raise ValueError(
+                "We don't know the indices of the non-relevant/relevant tokens for "
+                f"the checkpoint {pretrained_model_name_or_path} and you did not provide any."
+            )
+
     def run_llm_batched(
         self,
         prompts: List[str],
@@ -70,10 +110,15 @@ def run_llm_batched(
             ]
 
+            token_false_id, token_true_id = self.get_prediction_tokens(
+                self._model, self._tokenizer
+            )
             for logit_tensor in batch_logits[0]:
-                truth_logit = logit_tensor[1176]
-                false_logit = logit_tensor[6136]
-                score = math.exp(truth_logit) / (
-                    math.exp(truth_logit) + math.exp(false_logit)
+                # Look up the checkpoint-specific relevant/non-relevant token logits
+                # instead of hard-coding T5 vocabulary indices.
+                true_logit = logit_tensor[token_true_id]
+                false_logit = logit_tensor[token_false_id]
+                score = math.exp(true_logit) / (
+                    math.exp(true_logit) + math.exp(false_logit)
                 )
                 all_scores.append(score)
                 all_output_token_counts.append(self.num_output_tokens)
@@ -82,6 +127,8 @@ def run_llm_batched(
 
         return all_outputs, all_output_token_counts, all_scores
 
+
+
     def run_llm(self, prompt: str) -> Tuple[str, int, float]:
         gen_cfg = GenerationConfig.from_model_config(self._llm.config)
         gen_cfg.max_new_tokens = self.num_output_tokens()
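For clarity, the scoring step above reduces to a two-way softmax over the logits of the checkpoint's relevant and non-relevant tokens. A small standalone sketch, not part of the patch; the checkpoint choice and the logit values are made up for illustration:

import math

from transformers import T5Tokenizer

# Assumed checkpoint: any key of PREDICTION_TOKENS works the same way.
model_name = "castorini/monot5-base-msmarco-10k"
tokenizer = T5Tokenizer.from_pretrained(model_name)

token_false, token_true = "▁false", "▁true"  # PREDICTION_TOKENS[model_name]
token_false_id = tokenizer.get_vocab()[token_false]
token_true_id = tokenizer.get_vocab()[token_true]

# Made-up logits for one document; in run_llm_batched they come from logit_tensor.
true_logit, false_logit = 2.3, -1.7
score = math.exp(true_logit) / (math.exp(true_logit) + math.exp(false_logit))
print(token_false_id, token_true_id, round(score, 4))  # score ≈ 0.982
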
diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py
index 7df7b1c8..5a34629e 100644
--- a/src/rank_llm/rerank/rankllm.py
+++ b/src/rank_llm/rerank/rankllm.py
@@ -15,6 +15,7 @@ class PromptMode(Enum):
     LRL = "LRL"
     MONOT5 = "monot5"
     LiT5 = "LiT5"
+    BGE_RERANKER_V2 = "bge-reranker-v2"
 
     def __str__(self):
         return self.value
diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py
index 59b32ef2..b3ac6d49 100644
--- a/src/rank_llm/rerank/reranker.py
+++ b/src/rank_llm/rerank/reranker.py
@@ -11,6 +11,7 @@
 from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeOpenai
 from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore
+from rank_llm.rerank.pointwise.bge_reranker_v2 import BGE_RERANKER_V2
 from rank_llm.rerank.pointwise.monot5 import MonoT5
 from rank_llm.rerank.rankllm import RankLLM
 
 
@@ -279,6 +280,30 @@ def create_agent(
                 batch_size=batch_size,
             )
 
+        elif "mt5" in model_path:
+            # mT5 checkpoints are scored with the MonoT5 agent.
+            print(f"Loading {model_path} ...")
+
+            keys_and_defaults = [
+                ("prompt_mode", PromptMode.MONOT5),
+                ("context_size", 512),
+                ("device", "cuda"),
+                ("batch_size", 64),
+                ("dtype", "float32"),
+            ]
+            [prompt_mode, context_size, device, batch_size, dtype] = extract_kwargs(
+                keys_and_defaults, **kwargs
+            )
+
+            agent = MonoT5(
+                model=model_path,
+                prompt_mode=prompt_mode,
+                context_size=context_size,
+                device=device,
+                batch_size=batch_size,
+                dtype=dtype,
+            )
+
         elif "lit5-distill" in model_path.lower():
             keys_and_defaults = [
                 ("context_size", 150),
@@ -345,6 +370,36 @@ def create_agent(
                 batched=vllm_batched,
             )
             print(f"Completed loading {model_path}")
+
+        elif "bge-reranker" in model_path.lower():
+            print(f"Loading {model_path} ...")
+
+            keys_and_defaults = [
+                ("device", "cuda"),
+                ("use_fp16", True),
+                ("prompt_mode", PromptMode.BGE_RERANKER_V2),
+                ("context_size", 8192),
+                ("batch_size", 64),
+            ]
+            (
+                device,
+                use_fp16,
+                prompt_mode,
+                context_size,
+                batch_size,
+            ) = extract_kwargs(keys_and_defaults, **kwargs)
+
+            agent = BGE_RERANKER_V2(
+                model=model_path,
+                prompt_mode=prompt_mode,
+                context_size=context_size,
+                use_fp16=use_fp16,
+                batch_size=batch_size,
+                device=device,
+            )
+
+            print(f"Completed loading {model_path}")
+
         elif model_path in ["unspecified", "rank_random", "rank_identity"]:
             # NULL reranker
             agent = None
diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py
index 3639525b..f3fa8a22 100644
--- a/src/rank_llm/scripts/run_rank_llm.py
+++ b/src/rank_llm/scripts/run_rank_llm.py
@@ -38,6 +38,7 @@ def main(args):
     window_size = args.window_size
     system_message = args.system_message
     vllm_batched = args.vllm_batched
+    dtype = args.dtype
 
     _ = retrieve_and_rerank(
         model_path=model_path,
@@ -62,6 +63,7 @@ def main(args):
         step_size=step_size,
         system_message=system_message,
         vllm_batched=vllm_batched,
+        dtype=dtype,
     )
 
 
@@ -175,5 +177,11 @@ def main(args):
         action="store_true",
         help="whether to run the model in batches",
     )
+    parser.add_argument(
+        "--dtype",
+        default="float16",
+        type=str,
+        help="which dtype to run inference with; either float16 or float32",
+    )
     args = parser.parse_args()
     main(args)
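A minimal usage sketch, not part of the patch, showing how the two pointwise agents can be constructed directly with the same keyword arguments create_agent passes above; the checkpoint names are illustrative assumptions:

from rank_llm.rerank.pointwise.bge_reranker_v2 import BGE_RERANKER_V2
from rank_llm.rerank.pointwise.monot5 import MonoT5
from rank_llm.rerank.rankllm import PromptMode

# mT5 path: the new dtype argument (exposed on the CLI as --dtype) selects
# float16 or float32 weights for MT5ForConditionalGeneration.
mt5_agent = MonoT5(
    model="unicamp-dl/mt5-base-mmarco-v2",  # must be a key of PREDICTION_TOKENS
    prompt_mode=PromptMode.MONOT5,
    context_size=512,
    device="cuda",
    batch_size=64,
    dtype="float32",
)

# BGE path: the checkpoint name decides which FlagEmbedding wrapper is instantiated.
bge_agent = BGE_RERANKER_V2(
    model="BAAI/bge-reranker-v2-m3",  # assumed checkpoint; hits the FlagReranker branch
    prompt_mode=PromptMode.BGE_RERANKER_V2,
    context_size=8192,
    device="cuda",
    batch_size=64,
    use_fp16=True,
)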