from __future__ import annotations

import logging
+ import math
import os
+ import queue
import tempfile
import traceback
from collections.abc import Callable
+ from multiprocessing import Queue
from pathlib import Path
- from typing import Literal, overload
+ from typing import Any, Literal, overload

import numpy as np
import torch
+ import torch.multiprocessing as mp
from huggingface_hub import HfApi
from packaging import version
from torch import nn
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
+     is_torch_npu_available,
)
from transformers.utils import PushToHubMixin
from typing_extensions import deprecated
@@ -265,6 +270,193 @@ def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
        """
        return self.backend

+     def start_multi_process_pool(
+         self, target_devices: list[str] | None = None
+     ) -> dict[Literal["input", "output", "processes"], Any]:
+         """
+         Starts a multi-process pool that runs predictions with several independent processes
+         via :meth:`CrossEncoder.predict <sentence_transformers.cross_encoder.CrossEncoder.predict>`.
+ 
+         This method is recommended if you want to predict on multiple GPUs or CPUs. It is advised
+         to start only one process per GPU. This method works together with :meth:`predict`
+         and :meth:`stop_multi_process_pool`.
+ 
+         Args:
+             target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
+                 ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
+                 devices are available, all of them will be used. If target_devices is None and no CUDA/NPU
+                 device is available, 4 CPU workers will be used.
+ 
+         Returns:
+             Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
+         """
+         if target_devices is None:
+             if torch.cuda.is_available():
+                 target_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+             elif is_torch_npu_available():
+                 target_devices = [f"npu:{i}" for i in range(torch.npu.device_count())]
+             else:
+                 logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
+                 target_devices = ["cpu"] * 4
+ 
+         logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))
+ 
+         self.to("cpu")
+         self.share_memory()
+ 
+         ctx = mp.get_context("spawn")
+         input_queue = ctx.Queue()
+         output_queue = ctx.Queue()
+         processes = []
+ 
+         for device_id in target_devices:
+             p = ctx.Process(
+                 target=self.__class__._predict_multi_process_worker,
+                 args=(device_id, self, input_queue, output_queue),
+                 daemon=True,
+             )
+             p.start()
+             processes.append(p)
+ 
+         return {"input": input_queue, "output": output_queue, "processes": processes}
+ 
+     @staticmethod
+     def stop_multi_process_pool(pool: dict[Literal["input", "output", "processes"], Any]) -> None:
+         """
+         Stops all processes started with :meth:`start_multi_process_pool`.
+ 
+         Args:
+             pool (Dict[str, Any]): A dictionary containing the input queue, output queue, and process list.
+ 
+         Returns:
+             None
+         """
+         for p in pool["processes"]:
+             p.terminate()
+ 
+         for p in pool["processes"]:
+             p.join()
+             p.close()
+ 
+         pool["input"].close()
+         pool["output"].close()
+ 
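
# A minimal usage sketch of the pool lifecycle introduced above (illustrative,
# not part of the diff). The checkpoint name is just an example; any CrossEncoder
# model works. The __main__ guard is required because the pool uses the "spawn"
# start method.
from sentence_transformers import CrossEncoder

if __name__ == "__main__":
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    pool = model.start_multi_process_pool(["cpu", "cpu"])  # or e.g. ["cuda:0", "cuda:1"]
    try:
        scores = model.predict(
            [("How many people live in Berlin?", "Berlin has about 3.7 million inhabitants.")] * 1000,
            pool=pool,
        )
    finally:
        model.stop_multi_process_pool(pool)
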
+     def _predict_multi_process(
+         self,
+         sentence_pairs: list[tuple[str, str]] | list[list[str]],
+         show_progress_bar: bool | None = True,
+         input_was_singular: bool = False,
+         pool: dict[Literal["input", "output", "processes"], Any] | None = None,
+         device: str | list[str | torch.device] | None = None,
+         chunk_size: int | None = None,
+         **predict_kwargs,
+     ):
+         convert_to_tensor = predict_kwargs.get("convert_to_tensor", False)
+         convert_to_numpy = predict_kwargs.get("convert_to_numpy", True)
+         predict_kwargs["show_progress_bar"] = False
+ 
+         # Create a pool if one is not provided but a list of devices is
+         created_pool = False
+         if pool is None and isinstance(device, list) and len(device) > 0:
+             pool = self.start_multi_process_pool(device)
+             created_pool = True
+ 
+         try:
+             # Determine chunk size if not provided. As a default, aim for 10 chunks
+             # per process, with a maximum of 5000 sentence pairs per chunk.
+             if chunk_size is None:
+                 chunk_size = min(math.ceil(len(sentence_pairs) / len(pool["processes"]) / 10), 5000)
+                 chunk_size = max(chunk_size, 1)  # Ensure at least 1
+ 
+             input_queue: torch.multiprocessing.Queue = pool["input"]
+             output_queue: torch.multiprocessing.Queue = pool["output"]
+ 
+             # Send inputs to the input queue in chunks
+             chunk_id = -1  # Default to -1 to handle empty input gracefully
+             for chunk_id, chunk_start in enumerate(range(0, len(sentence_pairs), chunk_size)):
+                 chunk = sentence_pairs[chunk_start : chunk_start + chunk_size]
+                 input_queue.put([chunk_id, chunk, predict_kwargs])
+ 
+             # Collect results from the output queue, sorted by chunk id
+             output_list = sorted(
+                 [output_queue.get() for _ in trange(chunk_id + 1, desc="Chunks", disable=not show_progress_bar)],
+                 key=lambda x: x[0],  # Sort by chunk_id
+             )
+ 
+             # Check for errors before touching the scores; a failed worker
+             # reports [chunk_id, None, error_message]
+             if any(len(output) > 2 and output[2] is not None for output in output_list):
+                 error_output = next(output for output in output_list if len(output) > 2 and output[2] is not None)
+                 raise RuntimeError(f"Error in worker process: {error_output[2]}")
+ 
+             # Handle singular input case -> return the first (only) result directly
+             if input_was_singular and output_list:
+                 return output_list[0][1][0] if len(output_list[0][1]) > 0 else output_list[0][1]
+ 
+             # Handle the various output formats: torch tensors, numpy arrays, or
+             # lists of dictionaries, also when empty
+             scores = [output[1] for output in output_list]
+ 
+             if scores:
+                 if isinstance(scores[0], torch.Tensor):
+                     scores = torch.cat(scores)
+                 elif isinstance(scores[0], np.ndarray):
+                     scores = np.concatenate(scores, axis=0)
+                 else:
+                     # Lists of per-pair outputs (e.g. dictionaries) are flattened into one list
+                     scores = sum(scores, [])
+ 
+             elif convert_to_tensor:
+                 scores = torch.tensor([], device=self.model.device)
+             elif convert_to_numpy:
+                 scores = np.array([])
+             else:
+                 scores = []
+             return scores
+ 
+         finally:
+             # Clean up the pool if we created it
+             if created_pool:
+                 self.stop_multi_process_pool(pool)
+ 
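
# Worked example of the default chunking heuristic above (illustrative, not part
# of the diff): aim for ~10 chunks per process, capped at 5000 pairs per chunk
# and floored at 1.
import math

num_pairs, num_processes = 120_000, 4
chunk_size = max(min(math.ceil(num_pairs / num_processes / 10), 5000), 1)
print(chunk_size, math.ceil(num_pairs / chunk_size))  # 3000 pairs per chunk, 40 chunks
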
+     @staticmethod
+     def _predict_multi_process_worker(
+         target_device: str,
+         model: CrossEncoder,
+         input_queue: Queue,
+         results_queue: Queue,
+     ) -> None:
+         """
+         Internal worker process that predicts sentence pairs in a multi-process setup.
+         """
+         chunk_id = -1  # Fallback so the error path below never hits an unbound name
+         while True:
+             try:
+                 chunk_id, sentence_pairs, kwargs = input_queue.get()
+                 scores = model.predict(sentence_pairs, device=target_device, **kwargs)
+ 
+                 # Move any scores that are not on CPU to CPU, so they can all be concatenated later
+                 if isinstance(scores, torch.Tensor) and scores.device.type != "cpu":
+                     scores = scores.cpu()
+                 elif isinstance(scores, list):
+                     scores = [
+                         score.cpu() if isinstance(score, torch.Tensor) and score.device.type != "cpu" else score
+                         for score in scores
+                     ]
+                 results_queue.put([chunk_id, scores])
+ 
+             except queue.Empty:
+                 break
+             except Exception as e:
+                 logger.error(f"Error in worker process on {target_device}: {e}")
+                 try:
+                     results_queue.put([chunk_id, None, str(e)])
+                 except Exception:
+                     pass
+                 break
+ 
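
# The queue protocol the worker above implements, with hypothetical payloads
# (illustrative, not part of the diff): the driver puts [chunk_id, chunk,
# predict_kwargs] on the input queue, and the worker answers with
# [chunk_id, scores] on success or [chunk_id, None, error_message] on failure.
#
#   pool["input"].put([0, [("q", "doc A"), ("q", "doc B")], {"batch_size": 32}])
#   result = pool["output"].get()  # e.g. [0, array([0.71, 0.08], dtype=float32)]
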
    def set_activation_fn(self, activation_fn: Callable | None, set_default: bool = True) -> None:
        if activation_fn is not None:
            self.activation_fn = activation_fn
@@ -401,6 +593,9 @@ def predict(
        apply_softmax: bool | None = False,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
+         pool: dict[Literal["input", "output", "processes"], Any] | None = None,
+         device: str | list[str | torch.device] | None = None,
+         chunk_size: int | None = None,
    ) -> list[torch.Tensor] | np.ndarray | torch.Tensor:
        """
        Performs predictions with the CrossEncoder on the given sentence pairs.
@@ -420,6 +615,14 @@ def predict(
                a list of PyTorch tensors. Defaults to True.
            convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
                Defaults to False.
+             pool (Dict[str, Any], optional): A pool of workers created with :meth:`start_multi_process_pool`. If provided,
+                 multiprocessing will be used. If None and ``device`` is a list, a pool will be created automatically.
+                 Defaults to None.
+             device (Union[str, List[str]], optional): Device(s) to use for computation. Can be a single device string
+                 (e.g., "cuda:0", "cpu") or a list of devices (e.g., ["cuda:0", "cuda:1"]). If a list is provided,
+                 multiprocessing will be used automatically. Defaults to None.
+             chunk_size (int, optional): Size of chunks for multiprocessing. If None, a sensible default is calculated.
+                 Only used when ``pool`` is not None or ``device`` is a list. Defaults to None.

        Returns:
            Union[List[torch.Tensor], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs.
@@ -437,13 +640,37 @@ def predict(
            sentences = [["I love cats", "Cats are amazing"], ["I prefer dogs", "Dogs are loyal"]]
            model.predict(sentences)
            # => array([0.6912767, 0.4303499], dtype=float32)
+ 
+             # Using multiprocessing with an automatically created pool
+             scores = model.predict(sentences, device=["cuda:0", "cuda:1"])
+ 
+             # Using multiprocessing with a manually managed pool
+             pool = model.start_multi_process_pool()
+             scores = model.predict(sentences, pool=pool)
+             model.stop_multi_process_pool(pool)
            """
        # Cast an individual pair to a list with length 1
        input_was_singular = False
        if sentences and isinstance(sentences, (list, tuple)) and isinstance(sentences[0], str):
            sentences = [sentences]
            input_was_singular = True

+         # If a pool or a list of devices is provided, use multi-process prediction
+         if pool is not None or (isinstance(device, list) and len(device) > 0):
+             return self._predict_multi_process(
+                 sentence_pairs=sentences,
+                 show_progress_bar=show_progress_bar,
+                 input_was_singular=input_was_singular,
+                 pool=pool,
+                 device=device,
+                 chunk_size=chunk_size,
+                 batch_size=batch_size,
+                 activation_fn=activation_fn,
+                 apply_softmax=apply_softmax,
+                 convert_to_numpy=convert_to_numpy,
+                 convert_to_tensor=convert_to_tensor,
+             )
+ 
        if show_progress_bar is None:
            show_progress_bar = (
                logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
@@ -452,6 +679,11 @@ def predict(
        if activation_fn is not None:
            self.set_activation_fn(activation_fn, set_default=False)

+         target_device = self.model.device
+         if isinstance(device, str):
+             target_device = device
+             self.to(device)
+ 
        pred_scores = []
        self.eval()
        for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
@@ -462,7 +694,7 @@ def predict(
                truncation=True,
                return_tensors="pt",
            )
-             features.to(self.model.device)
+             features.to(target_device)
            model_predictions = self.model(**features, return_dict=True)
            logits = self.activation_fn(model_predictions.logits)

@@ -477,7 +709,7 @@ def predict(
            if len(pred_scores):
                pred_scores = torch.stack(pred_scores)
            else:
-                 pred_scores = torch.tensor([], device=self.model.device)
+                 pred_scores = torch.tensor([], device=target_device)
        elif convert_to_numpy:
            pred_scores = np.asarray([score.cpu().detach().float().numpy() for score in pred_scores])

@@ -499,6 +731,9 @@ def rank(
        apply_softmax=False,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
+         pool: dict[Literal["input", "output", "processes"], Any] | None = None,
+         device: str | list[str | torch.device] | None = None,
+         chunk_size: int | None = None,
    ) -> list[dict[Literal["corpus_id", "score", "text"], int | float | str]]:
        """
        Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
@@ -514,6 +749,14 @@ def rank(
            convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
            apply_softmax (bool, optional): If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output. Defaults to False.
            convert_to_tensor (bool, optional): Convert the output to a tensor. Defaults to False.
+             pool (Dict[str, Any], optional): A pool of workers created with :meth:`start_multi_process_pool`. If provided,
+                 multiprocessing will be used. If None and ``device`` is a list, a pool will be created automatically.
+                 Defaults to None.
+             device (Union[str, List[str]], optional): Device(s) to use for computation. Can be a single device string
+                 (e.g., "cuda:0", "cpu") or a list of devices (e.g., ["cuda:0", "cuda:1"]). If a list is provided,
+                 multiprocessing will be used automatically. Defaults to None.
+             chunk_size (int, optional): Size of chunks for multiprocessing. If None, a sensible default is calculated.
+                 Only used when ``pool`` is not None or ``device`` is a list. Defaults to None.

        Returns:
            List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]: A sorted list with the "corpus_id", "score", and optionally "text" of the documents.
@@ -568,6 +811,9 @@ def rank(
            apply_softmax=apply_softmax,
            convert_to_numpy=convert_to_numpy,
            convert_to_tensor=convert_to_tensor,
+             pool=pool,
+             device=device,
+             chunk_size=chunk_size,
        )

        results = []