
Commit 3386129

add multiprocessing support for Cross Encoder

1 parent 6aaa53b
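For context, the new API can be exercised as in this minimal sketch, adapted from the example this commit adds to the predict docstring (the model name is illustrative):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")  # illustrative model name
    pairs = [["I love cats", "Cats are amazing"], ["I prefer dogs", "Dogs are loyal"]]

    # Passing a list of devices spins up a pool automatically for this one call
    scores = model.predict(pairs, device=["cuda:0", "cuda:1"])

    # A manual pool can be reused across several predict() calls
    pool = model.start_multi_process_pool()
    scores = model.predict(pairs, pool=pool)
    model.stop_multi_process_pool(pool)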

File tree: 1 file changed, +249 −3 lines


sentence_transformers/cross_encoder/CrossEncoder.py

Lines changed: 249 additions & 3 deletions
@@ -1,15 +1,19 @@
 from __future__ import annotations

 import logging
+import math
 import os
+import queue
 import tempfile
 import traceback
 from collections.abc import Callable
+from multiprocessing import Queue
 from pathlib import Path
-from typing import Literal, overload
+from typing import Any, Literal, overload

 import numpy as np
 import torch
+import torch.multiprocessing as mp
 from huggingface_hub import HfApi
 from packaging import version
 from torch import nn
@@ -21,6 +25,7 @@
     PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizer,
+    is_torch_npu_available,
 )
 from transformers.utils import PushToHubMixin
 from typing_extensions import deprecated
@@ -265,6 +270,193 @@ def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
         """
         return self.backend

+    def start_multi_process_pool(
+        self, target_devices: list[str] | None = None
+    ) -> dict[Literal["input", "output", "processes"], Any]:
+        """
+        Starts a multi-process pool to run predictions with several independent processes
+        via :meth:`CrossEncoder.predict <sentence_transformers.cross_encoder.CrossEncoder.predict>`.
+
+        This method is recommended if you want to predict on multiple GPUs or CPUs. It is advised
+        to start only one process per GPU. This method works together with predict and
+        stop_multi_process_pool.
+
+        Args:
+            target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
+                ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
+                is available, then all available CUDA/NPU devices will be used. If target_devices is None and
+                CUDA/NPU is not available, then 4 CPU devices will be used.
+
+        Returns:
+            Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
+        """
+        if target_devices is None:
+            if torch.cuda.is_available():
+                target_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+            elif is_torch_npu_available():
+                target_devices = [f"npu:{i}" for i in range(torch.npu.device_count())]
+            else:
+                logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
+                target_devices = ["cpu"] * 4
+
+        logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))
+
+        self.to("cpu")
+        self.share_memory()
+
+        ctx = mp.get_context("spawn")
+        input_queue = ctx.Queue()
+        output_queue = ctx.Queue()
+        processes = []
+
+        for device_id in target_devices:
+            p = ctx.Process(
+                target=self.__class__._predict_multi_process_worker,
+                args=(device_id, self, input_queue, output_queue),
+                daemon=True,
+            )
+            p.start()
+            processes.append(p)
+
+        return {"input": input_queue, "output": output_queue, "processes": processes}
+
+    @staticmethod
+    def stop_multi_process_pool(pool: dict[Literal["input", "output", "processes"], Any]) -> None:
+        """
+        Stops all processes started with start_multi_process_pool.
+
+        Args:
+            pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
+
+        Returns:
+            None
+        """
+        for p in pool["processes"]:
+            p.terminate()
+
+        for p in pool["processes"]:
+            p.join()
+            p.close()
+
+        pool["input"].close()
+        pool["output"].close()
+
+    def _predict_multi_process(
+        self,
+        sentence_pairs: list[tuple[str, str]] | list[list[str]],
+        show_progress_bar: bool | None = True,
+        input_was_singular: bool = False,
+        pool: dict[Literal["input", "output", "processes"], Any] | None = None,
+        device: str | list[str | torch.device] | None = None,
+        chunk_size: int | None = None,
+        **predict_kwargs,
+    ):
+        convert_to_tensor = predict_kwargs.get("convert_to_tensor", False)
+        convert_to_numpy = predict_kwargs.get("convert_to_numpy", True)
+        predict_kwargs["show_progress_bar"] = False
+
+        # Create a pool if one is not provided but a list of devices is
+        created_pool = False
+        if pool is None and isinstance(device, list) and len(device) > 0:
+            pool = self.start_multi_process_pool(device)
+            created_pool = True
+
+        try:
+            # Determine the chunk size if not provided. As a default, aim for 10 chunks per process, with a maximum of 5000 sentence pairs per chunk.
+            if chunk_size is None:
+                chunk_size = min(math.ceil(len(sentence_pairs) / len(pool["processes"]) / 10), 5000)
+                chunk_size = max(chunk_size, 1)  # Ensure at least 1
+
+            input_queue: torch.multiprocessing.Queue = pool["input"]
+            output_queue: torch.multiprocessing.Queue = pool["output"]
+
+            # Send inputs to the input queue in chunks
+            chunk_id = -1  # We default to -1 to handle empty input gracefully
+            for chunk_id, chunk_start in enumerate(range(0, len(sentence_pairs), chunk_size)):
+                chunk = sentence_pairs[chunk_start : chunk_start + chunk_size]
+                input_queue.put([chunk_id, chunk, predict_kwargs])
+
+            # Collect results from the output queue, sorted by chunk_id as workers may finish out of order
+            output_list = sorted(
+                [output_queue.get() for _ in trange(chunk_id + 1, desc="Chunks", disable=not show_progress_bar)],
+                key=lambda x: x[0],  # Sort by chunk_id
+            )
+
+            # Check for errors in the results before using them
+            if any(len(output) > 2 and output[2] is not None for output in output_list):
+                # An error occurred in a worker process
+                error_output = next(output for output in output_list if len(output) > 2 and output[2])
+                raise RuntimeError(f"Error in worker process: {error_output[2]}")
+
+            # Handle the singular input case -> return the first (only) result directly
+            if input_was_singular and output_list:
+                return output_list[0][1][0] if len(output_list[0][1]) > 0 else output_list[0][1]
+
+            # Handle the various output formats: torch tensors, numpy arrays, or
+            # lists of dictionaries, also when empty
+            scores = [output[1] for output in output_list]
+
+            if scores:
+                if isinstance(scores[0], torch.Tensor):
+                    scores = torch.cat(scores)
+                elif isinstance(scores[0], np.ndarray):
+                    scores = np.concatenate(scores, axis=0)
+                else:
+                    # Lists, e.g. of dictionaries, are simply concatenated
+                    scores = sum(scores, [])
+            elif convert_to_tensor:
+                scores = torch.tensor([], device=self.model.device)
+            elif convert_to_numpy:
+                scores = np.array([])
+            else:
+                scores = []
+            return scores

+        finally:
+            # Clean up the pool if we created it
+            if created_pool:
+                self.stop_multi_process_pool(pool)
+
+    @staticmethod
+    def _predict_multi_process_worker(
+        target_device: str,
+        model: CrossEncoder,
+        input_queue: Queue,
+        results_queue: Queue,
+    ) -> None:
+        """
+        Internal worker process that predicts scores for chunks of sentence pairs in a multi-process setup.
+        """
+        while True:
+            try:
+                chunk_id, sentence_pairs, kwargs = input_queue.get()
+                scores = model.predict(sentence_pairs, device=target_device, **kwargs)
+
+                # If the scores are not on the CPU, move them there so they can all be concatenated later
+                if isinstance(scores, torch.Tensor) and scores.device.type != "cpu":
+                    scores = scores.cpu()
+                elif isinstance(scores, np.ndarray):
+                    scores = np.asarray(scores)
+                elif isinstance(scores, list):
+                    scores = [
+                        score.cpu() if isinstance(score, torch.Tensor) and score.device.type != "cpu" else score
+                        for score in scores
+                    ]
+                results_queue.put([chunk_id, scores])
+
+            except queue.Empty:
+                break
+            except Exception as e:
+                logger.error(f"Error in worker process on {target_device}: {e}")
+                try:
+                    results_queue.put([chunk_id, None, str(e)])
+                except Exception:
+                    pass
+                break

     def set_activation_fn(self, activation_fn: Callable | None, set_default: bool = True) -> None:
         if activation_fn is not None:
             self.activation_fn = activation_fn
@@ -401,6 +593,9 @@ def predict(
         apply_softmax: bool | None = False,
         convert_to_numpy: bool = True,
         convert_to_tensor: bool = False,
+        pool: dict[Literal["input", "output", "processes"], Any] | None = None,
+        device: str | list[str | torch.device] | None = None,
+        chunk_size: int | None = None,
     ) -> list[torch.Tensor] | np.ndarray | torch.Tensor:
         """
         Performs predictions with the CrossEncoder on the given sentence pairs.
@@ -420,6 +615,14 @@ def predict(
                 a list of PyTorch tensors. Defaults to True.
             convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
                 Defaults to False.
+            pool (Dict[str, Any], optional): A pool of workers created with :meth:`start_multi_process_pool`. If provided,
+                multiprocessing will be used. If None and ``device`` is a list, a pool will be created automatically.
+                Defaults to None.
+            device (Union[str, List[str]], optional): Device(s) to use for computation. Can be a single device string
+                (e.g., "cuda:0", "cpu") or a list of devices (e.g., ["cuda:0", "cuda:1"]). If a list is provided,
+                multiprocessing will be used automatically. Defaults to None.
+            chunk_size (int, optional): Size of chunks for multiprocessing. If None, a sensible default is calculated.
+                Only used when ``pool`` is not None or ``device`` is a list. Defaults to None.

         Returns:
             Union[List[torch.Tensor], np.ndarray, torch.Tensor]: Predictions for the passed sentence pairs.
@@ -437,13 +640,37 @@ def predict(
                 sentences = [["I love cats", "Cats are amazing"], ["I prefer dogs", "Dogs are loyal"]]
                 model.predict(sentences)
                 # => array([0.6912767, 0.4303499], dtype=float32)
+
+                # Using multiprocessing with automatic pool
+                scores = model.predict(sentences, device=["cuda:0", "cuda:1"])
+
+                # Using multiprocessing with manual pool
+                pool = model.start_multi_process_pool()
+                scores = model.predict(sentences, pool=pool)
+                model.stop_multi_process_pool(pool)
         """
         # Cast an individual pair to a list with length 1
         input_was_singular = False
         if sentences and isinstance(sentences, (list, tuple)) and isinstance(sentences[0], str):
             sentences = [sentences]
             input_was_singular = True

+        # If a pool or a list of devices is provided, use multi-process prediction
+        if pool is not None or (isinstance(device, list) and len(device) > 0):
+            return self._predict_multi_process(
+                sentence_pairs=sentences,
+                show_progress_bar=show_progress_bar,
+                input_was_singular=input_was_singular,
+                pool=pool,
+                device=device,
+                chunk_size=chunk_size,
+                batch_size=batch_size,
+                activation_fn=activation_fn,
+                apply_softmax=apply_softmax,
+                convert_to_numpy=convert_to_numpy,
+                convert_to_tensor=convert_to_tensor,
+            )
+
         if show_progress_bar is None:
             show_progress_bar = (
                 logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
@@ -452,6 +679,11 @@ def predict(
         if activation_fn is not None:
             self.set_activation_fn(activation_fn, set_default=False)

+        target_device = self.model.device
+        if device is not None and isinstance(device, str):
+            target_device = device
+            self.to(device)
+
         pred_scores = []
         self.eval()
         for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
@@ -462,7 +694,7 @@ def predict(
                 truncation=True,
                 return_tensors="pt",
             )
-            features.to(self.model.device)
+            features.to(target_device)
             model_predictions = self.model(**features, return_dict=True)
             logits = self.activation_fn(model_predictions.logits)

@@ -477,7 +709,7 @@ def predict(
             if len(pred_scores):
                 pred_scores = torch.stack(pred_scores)
             else:
-                pred_scores = torch.tensor([], device=self.model.device)
+                pred_scores = torch.tensor([], device=target_device)
         elif convert_to_numpy:
             pred_scores = np.asarray([score.cpu().detach().float().numpy() for score in pred_scores])

@@ -499,6 +731,9 @@ def rank(
         apply_softmax=False,
         convert_to_numpy: bool = True,
         convert_to_tensor: bool = False,
+        pool: dict[Literal["input", "output", "processes"], Any] | None = None,
+        device: str | list[str | torch.device] | None = None,
+        chunk_size: int | None = None,
     ) -> list[dict[Literal["corpus_id", "score", "text"], int | float | str]]:
         """
         Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
@@ -514,6 +749,14 @@ def rank(
             convert_to_numpy (bool, optional): Convert the output to a numpy matrix. Defaults to True.
             apply_softmax (bool, optional): If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output. Defaults to False.
             convert_to_tensor (bool, optional): Convert the output to a tensor. Defaults to False.
+            pool (Dict[str, Any], optional): A pool of workers created with :meth:`start_multi_process_pool`. If provided,
+                multiprocessing will be used. If None and ``device`` is a list, a pool will be created automatically.
+                Defaults to None.
+            device (Union[str, List[str]], optional): Device(s) to use for computation. Can be a single device string
+                (e.g., "cuda:0", "cpu") or a list of devices (e.g., ["cuda:0", "cuda:1"]). If a list is provided,
+                multiprocessing will be used automatically. Defaults to None.
+            chunk_size (int, optional): Size of chunks for multiprocessing. If None, a sensible default is calculated.
+                Only used when ``pool`` is not None or ``device`` is a list. Defaults to None.

         Returns:
             List[Dict[Literal["corpus_id", "score", "text"], Union[int, float, str]]]: A sorted list with the "corpus_id", "score", and optionally "text" of the documents.
@@ -568,6 +811,9 @@ def rank(
             apply_softmax=apply_softmax,
             convert_to_numpy=convert_to_numpy,
             convert_to_tensor=convert_to_tensor,
+            pool=pool,
+            device=device,
+            chunk_size=chunk_size,
         )

         results = []

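A note on the chunking default in _predict_multi_process: when chunk_size is None, the commit aims for roughly 10 chunks per worker process, capped at 5000 pairs per chunk and floored at 1. A standalone sketch of that heuristic (the helper name is ours; the constants 10 and 5000 come from the diff):

    import math

    def default_chunk_size(num_pairs: int, num_processes: int) -> int:
        # Aim for ~10 chunks per process, but never more than 5000 pairs per chunk
        chunk_size = min(math.ceil(num_pairs / num_processes / 10), 5000)
        return max(chunk_size, 1)  # Floor at 1 so tiny inputs still form one chunk

    # 100_000 pairs on 4 workers -> ceil(100_000 / 4 / 10) = 2500 pairs per chunk
    # 200 pairs on 4 workers -> ceil(200 / 4 / 10) = 5 pairs per chunk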
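The queue protocol between the coordinator and the workers is worth spelling out: each input message is [chunk_id, chunk, predict_kwargs], each result is [chunk_id, scores] on success or [chunk_id, None, error_message] on failure, and results are sorted by chunk_id because workers can finish out of order. A minimal sketch of the reordering step under those assumptions (the queue contents are fabricated for illustration):

    # Results may arrive out of order; sorting by chunk_id restores the input order
    raw_results = [
        [2, [0.10, 0.90]],  # chunk 2 happened to finish first
        [0, [0.70, 0.30]],
        [1, [0.50, 0.50]],
    ]
    ordered = sorted(raw_results, key=lambda item: item[0])
    scores = [score for _, chunk_scores in ordered for score in chunk_scores]
    # => [0.7, 0.3, 0.5, 0.5, 0.1, 0.9]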