diff --git a/README.md b/README.md index e6fb882f..ede651e0 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,22 @@ conda install -c conda-forge openjdk=21 maven -y pip install "rank-llm[pyserini]" ``` +## Core Optional Dependencies + +### Install vLLM Support (Optional) +For local model inference using vLLM: +```bash +pip install -e .[vllm] # local installation for development +pip install rank-llm[vllm] # or pip installation +``` + +### Install Transformers Support (Optional) +For T5-based models (DuoT5, MonoT5, etc.): +```bash +pip install -e .[transformers] # local installation for development +pip install rank-llm[transformers] # or pip installation +``` + ## Install [all] Dependencies ```bash pip install -e .[all] # local installation for development diff --git a/docs/optional_dependencies.md b/docs/optional_dependencies.md new file mode 100644 index 00000000..ad318a7f --- /dev/null +++ b/docs/optional_dependencies.md @@ -0,0 +1,82 @@ +# Optional Dependencies in rank_llm + +rank_llm now supports optional dependencies to allow for lighter installations based on your specific needs. 
+ +## Installation Options + +### Core Installation +Install just the core functionality without heavy ML libraries: +```bash +pip install rank_llm +``` + +This includes the basic ranking functionality and API integrations, but excludes: +- vLLM for local model inference +- transformers library for T5-based models + +### With vLLM Support +For local model inference using vLLM: +```bash +pip install rank_llm[vllm] +``` + +This enables: +- `VllmHandler` for local model inference +- `RankListwiseOSLLM` and other vLLM-dependent models + +### With Transformers Support +For T5-based models and transformers functionality: +```bash +pip install rank_llm[transformers] +``` + +This enables: +- `DuoT5` pairwise ranking model +- `MonoT5` pointwise ranking model +- `RankFiD` listwise models +- All T5-based inference handlers + +### Full Installation +Install all optional dependencies: +```bash +pip install rank_llm[all] +``` + +This includes vLLM, transformers, and all other optional features. + +## Error Handling + +When you try to use functionality that requires missing optional dependencies, you'll get helpful error messages: + +```python +from rank_llm.rerank.vllm_handler import VllmHandler + +# If vLLM is not installed: +handler = VllmHandler(...) +# ImportError: vLLM is not installed. Please install it with: pip install rank_llm[vllm] +``` + +```python +from rank_llm.rerank.pairwise.duot5 import DuoT5 + +# If transformers is not installed: +model = DuoT5(...) +# ImportError: transformers is not installed. Please install it with: pip install rank_llm[transformers] +``` + +## Migration Guide + +If you were previously using rank_llm and now get import errors, you likely need to install the optional dependencies: + +1. **For vLLM users**: Run `pip install rank_llm[vllm]` +2. **For T5 model users**: Run `pip install rank_llm[transformers]` +3. **For both**: Run `pip install rank_llm[all]` + +## Development + +When developing rank_llm, install all dependencies: +```bash +pip install -e .[all] +``` + +This ensures you can test all functionality locally. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 25ea9161..e10ebd09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,12 @@ requires-python = ">= 3.11" dependencies = {file = ["requirements.txt"]} [project.optional-dependencies] +vllm = [ + "vllm>=0.4.0" +] +transformers = [ + "transformers>=4.40.1" +] sglang = [ "sglang[all]~=0.4.0" ] @@ -50,7 +56,9 @@ training = [ ] all = [ "rank-llm[genai]", - "rank-llm[pyserini]" + "rank-llm[pyserini]", + "rank-llm[vllm]", + "rank-llm[transformers]" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index c8d631bc..a8a18997 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ tqdm>=4.66.2 openai>=1.23.6 tiktoken>=0.6.0 -transformers>=4.40.1 python-dotenv>=1.0.1 faiss-cpu>=1.8.0 ftfy>=6.2.0 dacite>=1.8.1 -vllm>=0.4.0 pandas>=1.4.0 \ No newline at end of file diff --git a/src/rank_llm/rerank/listwise/lit5/model.py b/src/rank_llm/rerank/listwise/lit5/model.py index a5b4b675..ee746225 100644 --- a/src/rank_llm/rerank/listwise/lit5/model.py +++ b/src/rank_llm/rerank/listwise/lit5/model.py @@ -3,7 +3,14 @@ import torch from torch import nn -from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration, T5Stack + +try: + from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration, T5Stack + TRANSFORMERS_AVAILABLE = True +except ImportError: + raise ImportError( + "transformers is required for the lit5 model. Please install it with: pip install rank_llm[transformers]" + ) from .modeling_t5 import ( T5ForConditionalGeneration as T5ConditionalGenerationCrossAttentionScore, diff --git a/src/rank_llm/rerank/listwise/lit5/modeling_t5.py b/src/rank_llm/rerank/listwise/lit5/modeling_t5.py index c4d344f5..4322abe2 100644 --- a/src/rank_llm/rerank/listwise/lit5/modeling_t5.py +++ b/src/rank_llm/rerank/listwise/lit5/modeling_t5.py @@ -24,30 +24,37 @@ from torch import nn from torch.nn import CrossEntropyLoss from torch.utils.checkpoint import checkpoint -from transformers.activations import ACT2FN -from transformers.file_utils import ( - DUMMY_INPUTS, - DUMMY_MASK, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_torch_fx_proxy, - replace_return_docstrings, -) -from transformers.generation import GenerationMixin -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, - Seq2SeqModelOutput, -) -from transformers.modeling_utils import ( - PreTrainedModel, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from transformers.models.t5.configuration_t5 import T5Config -from transformers.utils import logging -from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + +try: + from transformers.activations import ACT2FN + from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, + ) + from transformers.generation import GenerationMixin + from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + ) + from transformers.modeling_utils import ( + PreTrainedModel, + find_pruneable_heads_and_indices, + prune_linear_layer, + ) + from transformers.models.t5.configuration_t5 import T5Config + from transformers.utils import logging + from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + TRANSFORMERS_AVAILABLE = True +except ImportError: + raise ImportError( + "transformers is required for the lit5 model. Please install it with: pip install rank_llm[transformers]" + ) logger = logging.get_logger(__name__) diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index e99a05e2..b3e32065 100644 --- a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -3,7 +3,13 @@ import torch from tqdm import tqdm -from transformers import T5Tokenizer + +try: + from transformers import T5Tokenizer + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + T5Tokenizer = None from rank_llm.data import Request, Result from rank_llm.rerank.listwise.listwise_rankllm import ListwiseRankLLM @@ -45,6 +51,11 @@ def __init__( """ Creates instance of the RankFiDDistill class, a specialized version of RankLLM designed from Lit5-Distill. """ + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not installed. Please install it with: pip install rank_llm[transformers]" + ) + super().__init__( model=model, context_size=context_size, @@ -269,6 +280,11 @@ def __init__( precision: str = "bfloat16", device: str = "cuda", ) -> None: + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not installed. Please install it with: pip install rank_llm[transformers]" + ) + super().__init__( model=model, context_size=context_size, diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index f04f6ae2..b37aeeb1 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -4,13 +4,19 @@ import unicodedata from concurrent.futures import ThreadPoolExecutor from importlib.resources import files -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch -import vllm from ftfy import fix_text from tqdm import tqdm +try: + import vllm + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + vllm = None + from rank_llm.data import Request, Result from rank_llm.rerank.rankllm import PromptMode from rank_llm.rerank.vllm_handler import VllmHandler @@ -225,7 +231,7 @@ def _evaluate_logits( def _get_logits_single_digit( self, - output: vllm.RequestOutput, + output: Union["vllm.RequestOutput", Any], effective_location: int = 1, total: Tuple[int, int] = (1, 9), ): diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index 192a1ff5..1964d4dd 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -4,8 +4,15 @@ from importlib.resources import files from typing import List, Optional, Tuple -from transformers import T5ForConditionalGeneration, T5Tokenizer -from transformers.generation import GenerationConfig +try: + from transformers import T5ForConditionalGeneration, T5Tokenizer + from transformers.generation import GenerationConfig + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + T5ForConditionalGeneration = None + T5Tokenizer = None + GenerationConfig = None from rank_llm.data import Result from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM 
@@ -28,6 +35,11 @@ def __init__( device: str = "cuda", batch_size: int = 32, ): + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not installed. Please install it with: pip install rank_llm[transformers]" + ) + super().__init__( model=model, context_size=context_size, diff --git a/src/rank_llm/rerank/pairwise/pairwise_inference_handler.py b/src/rank_llm/rerank/pairwise/pairwise_inference_handler.py index 2ff9f6e0..1fa7a0c8 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_inference_handler.py +++ b/src/rank_llm/rerank/pairwise/pairwise_inference_handler.py @@ -1,7 +1,12 @@ import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Union -from transformers import T5Tokenizer +try: + from transformers import T5Tokenizer + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + T5Tokenizer = None from rank_llm.data import Result, TemplateSectionConfig from rank_llm.rerank.inference_handler import BaseInferenceHandler @@ -96,7 +101,7 @@ def _generate_body( index1: int, index2: int, single_doc_max_token: int, - tokenizer: T5Tokenizer, + tokenizer: Union["T5Tokenizer", Any], ) -> str: doc1_raw = self._convert_doc_to_prompt_content( result.candidates[index1].doc, max_length=single_doc_max_token ) diff --git a/src/rank_llm/rerank/pointwise/monot5.py b/src/rank_llm/rerank/pointwise/monot5.py index 0a1c5bda..970bc159 100644 --- a/src/rank_llm/rerank/pointwise/monot5.py +++ b/src/rank_llm/rerank/pointwise/monot5.py @@ -3,8 +3,15 @@ from importlib.resources import files from typing import List, Optional, Tuple -from transformers import T5ForConditionalGeneration, T5Tokenizer -from transformers.generation import GenerationConfig +try: + from transformers import T5ForConditionalGeneration, T5Tokenizer + from transformers.generation import GenerationConfig + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + T5ForConditionalGeneration = None + T5Tokenizer = None + GenerationConfig = None from rank_llm.data import Result from rank_llm.rerank.pointwise.pointwise_rankllm import PointwiseRankLLM @@ -27,6 +34,11 @@ def __init__( device: str = "cuda", batch_size: int = 32, ): + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not installed. Please install it with: pip install rank_llm[transformers]" + ) + super().__init__( model=model, context_size=context_size, diff --git a/src/rank_llm/rerank/pointwise/pointwise_inference_handler.py b/src/rank_llm/rerank/pointwise/pointwise_inference_handler.py index 161df9fd..1fab1586 100644 --- a/src/rank_llm/rerank/pointwise/pointwise_inference_handler.py +++ b/src/rank_llm/rerank/pointwise/pointwise_inference_handler.py @@ -1,7 +1,12 @@ import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Union -from transformers import T5Tokenizer +try: + from transformers import T5Tokenizer + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + T5Tokenizer = None from rank_llm.data import Result, TemplateSectionConfig from rank_llm.rerank.inference_handler import BaseInferenceHandler @@ -86,7 +91,7 @@ def _generate_body( result: Result, index: int, max_doc_tokens: int, - tokenizer: T5Tokenizer, + tokenizer: Union["T5Tokenizer", Any], ) -> str: query = self._replace_number(result.query.text) doc_raw = self._convert_doc_to_prompt_content( diff --git a/src/rank_llm/rerank/vllm_handler.py b/src/rank_llm/rerank/vllm_handler.py index 89c3cf9d..b3c65af9 100644 --- a/src/rank_llm/rerank/vllm_handler.py +++ b/src/rank_llm/rerank/vllm_handler.py @@ -1,8 +1,20 @@ from typing import Any, Dict, List, Optional -import vllm -from transformers import PreTrainedTokenizerBase -from vllm.outputs import RequestOutput +try: + import vllm + from vllm.outputs import RequestOutput + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + vllm = None + RequestOutput = None + +try: + from transformers import PreTrainedTokenizerBase + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + PreTrainedTokenizerBase = None class VllmHandler: @@ -16,6 +28,11 @@ def __init__( gpu_memory_utilization: float, **kwargs: Any, ): + if not VLLM_AVAILABLE: + raise ImportError( + "vLLM is not installed. Please install it with: pip install rank_llm[vllm]" + ) + self._vllm = vllm.LLM( model=model, download_dir=download_dir, @@ -35,6 +52,10 @@ def __init__( ) def get_tokenizer(self) -> PreTrainedTokenizerBase: + if not TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers is not installed. Please install it with: pip install rank_llm[transformers]" + ) return self._tokenizer def generate_output( @@ -46,6 +67,11 @@ def generate_output( logprobs: Optional[int] = None, **kwargs: Any, ) -> List[RequestOutput]: + if not VLLM_AVAILABLE: + raise ImportError( + "vLLM is not installed. Please install it with: pip install rank_llm[vllm]" + ) + # TODO: Implement rest of vllm arguments (from kwargs) in the future if necessary sampling_params = vllm.SamplingParams( min_tokens=min_tokens, diff --git a/test/test_optional_dependencies.py b/test/test_optional_dependencies.py new file mode 100644 index 00000000..a5e48450 --- /dev/null +++ b/test/test_optional_dependencies.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +""" +Test for optional dependencies functionality. +This test validates that modules handle missing optional dependencies gracefully. 
+""" + +import unittest +import sys + + +class TestOptionalDependencies(unittest.TestCase): + """Test that optional dependencies are handled correctly.""" + + def test_vllm_optional_import(self): + """Test that vLLM can be optionally imported.""" + # Test the import logic that's used in vllm_handler.py + try: + import vllm + from vllm.outputs import RequestOutput + vllm_available = True + except ImportError: + vllm_available = False + vllm = None + RequestOutput = None + + # This should not raise an error regardless of whether vLLM is installed + self.assertIsInstance(vllm_available, bool) + + if not vllm_available: + self.assertIsNone(vllm) + self.assertIsNone(RequestOutput) + + def test_transformers_optional_import(self): + """Test that transformers can be optionally imported.""" + # Test the import logic used in various model files + try: + from transformers import T5ForConditionalGeneration, T5Tokenizer + from transformers.generation import GenerationConfig + from transformers import PreTrainedTokenizerBase + transformers_available = True + except ImportError: + transformers_available = False + T5ForConditionalGeneration = None + T5Tokenizer = None + GenerationConfig = None + PreTrainedTokenizerBase = None + + # This should not raise an error regardless of whether transformers is installed + self.assertIsInstance(transformers_available, bool) + + if not transformers_available: + self.assertIsNone(T5ForConditionalGeneration) + self.assertIsNone(T5Tokenizer) + self.assertIsNone(GenerationConfig) + self.assertIsNone(PreTrainedTokenizerBase) + + def test_error_messages_contain_install_instructions(self): + """Test that error messages include helpful installation instructions.""" + vllm_error_msg = "vLLM is not installed. Please install it with: pip install rank_llm[vllm]" + transformers_error_msg = "transformers is not installed. 
Please install it with: pip install rank_llm[transformers]" + + # Verify the error messages contain the expected installation instructions + self.assertIn("pip install rank_llm[vllm]", vllm_error_msg) + self.assertIn("pip install rank_llm[transformers]", transformers_error_msg) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file