81 changes: 81 additions & 0 deletions python/cuml/cuml/benchmark/algorithms.py
@@ -28,13 +28,16 @@
_build_gtil_classifier,
_build_mnmg_umap,
_build_optimized_fil_classifier,
_build_xgboost_classifier_for_training,
_build_xgboost_regressor_for_training,
_training_data_to_numpy,
_treelite_fil_accuracy_score,
fit,
fit_kneighbors,
fit_predict,
fit_transform,
predict,
train_xgboost,
transform,
)
from cuml.preprocessing import (
@@ -60,10 +63,12 @@


try:
import xgboost as xgb
from xgboost import XGBClassifier, XGBRegressor
except ImportError:
XGBClassifier = None
XGBRegressor = None
xgb = None


class AlgorithmPair:
@@ -365,6 +370,82 @@ def all_algorithms():
accepts_labels=True,
accuracy_function=metrics.r2_score,
),
AlgorithmPair(
xgb,
xgb,
shared_args={
"tree_method": "hist",
"n_estimators": 100,
"use_quantile_dmatrix": False,
"max_bin": 256,
},
cpu_args={"device": "cpu"},
cuml_args={"device": "cuda"},
name="xgboost-native-classification",
accepts_labels=False,
setup_cpu_func=_build_xgboost_classifier_for_training,
setup_cuml_func=_build_xgboost_classifier_for_training,
cpu_data_prep_hook=_labels_to_int_hook,
cuml_data_prep_hook=_labels_to_int_hook,
accuracy_function=metrics.accuracy_score,
bench_func=train_xgboost,
),
AlgorithmPair(
xgb,
xgb,
shared_args={
"tree_method": "hist",
"n_estimators": 100,
"use_quantile_dmatrix": True,
"max_bin": 256,
},
cpu_args={"device": "cpu"},
cuml_args={"device": "cuda"},
name="xgboost-native-classification-quantile",
accepts_labels=False,
setup_cpu_func=_build_xgboost_classifier_for_training,
setup_cuml_func=_build_xgboost_classifier_for_training,
cpu_data_prep_hook=_labels_to_int_hook,
cuml_data_prep_hook=_labels_to_int_hook,
accuracy_function=metrics.accuracy_score,
bench_func=train_xgboost,
),
AlgorithmPair(
xgb,
xgb,
shared_args={
"tree_method": "hist",
"n_estimators": 100,
"use_quantile_dmatrix": False,
"max_bin": 256,
},
cpu_args={"device": "cpu"},
cuml_args={"device": "cuda"},
name="xgboost-native-regression",
accepts_labels=False,
setup_cpu_func=_build_xgboost_regressor_for_training,
setup_cuml_func=_build_xgboost_regressor_for_training,
accuracy_function=metrics.r2_score,
bench_func=train_xgboost,
),
AlgorithmPair(
xgb,
xgb,
shared_args={
"tree_method": "hist",
"n_estimators": 100,
"use_quantile_dmatrix": True,
"max_bin": 256,
},
cpu_args={"device": "cpu"},
cuml_args={"device": "cuda"},
name="xgboost-native-regression-quantile",
accepts_labels=False,
setup_cpu_func=_build_xgboost_regressor_for_training,
setup_cuml_func=_build_xgboost_regressor_for_training,
accuracy_function=metrics.r2_score,
bench_func=train_xgboost,
),
AlgorithmPair(
sklearn.manifold.TSNE,
cuml.manifold.TSNE,
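As a rough sketch (not part of this diff) of how the new xgboost-native entries could be exercised, assuming this module's existing algorithm_by_name helper and cuml.benchmark.runners.run_variations keep their current signatures:

from cuml.benchmark.algorithms import algorithm_by_name
from cuml.benchmark.runners import run_variations

# Look up one of the entries registered above and benchmark it at a
# couple of dataset sizes; run_variations times both the CPU and the
# CUDA side of the AlgorithmPair.
algo = algorithm_by_name("xgboost-native-classification")
results = run_variations(
    [algo],
    dataset_name="classification",
    bench_rows=[10_000, 100_000],
    bench_dims=[32],
)
print(results)
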
184 changes: 184 additions & 0 deletions python/cuml/cuml/benchmark/bench_helper_funcs.py
@@ -20,6 +20,11 @@
from cuml.internals.device_type import DeviceType
from cuml.manifold import UMAP

try:
import xgboost as xgb
except ImportError:
xgb = None


def call(m, func_name, X, y=None):
def unwrap_and_get_args(func):
@@ -82,6 +87,43 @@ def fit_kneighbors(m, x, y=None):
call(m, "fit_kneighbors", x, y)


def train_xgboost(m, x, y=None):
"""
Bench function for XGBoost that times the training phase.

This function is designed to work with XGBoostTrainWrapper instances
that have been pre-configured by the setup functions. It calls the
retrain() method, which performs the actual training that gets timed.

Parameters
----------
m : XGBoostTrainWrapper
x : array-like
y : array-like, optional

Returns
-------
booster : xgboost.Booster
The trained XGBoost booster model.

Notes
-----
The x and y parameters are present for interface consistency with other
benchmark functions but are not used, since the training data is already
contained in the pre-built DMatrix inside the wrapper: the benchmarking
framework always calls bench_func with (model, data[0], data[1]) or
(model, data[0]).
"""
if hasattr(m, "retrain"):
return m.retrain()
else:
raise ValueError(
f"Expected XGBoostTrainWrapper with 'retrain' method, "
f"but got {type(m).__name__}. Ensure the setup function "
f"(_build_xgboost_*_for_training) was called correctly."
)


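For illustration, a sketch (not code in this diff) of the timed call path: the xgboost-native entries set accepts_labels=False, so the runner's call reduces to the line below; wrapper and X are hypothetical names here.

# `wrapper` stands in for the object returned by
# _build_xgboost_classifier_for_training; X is the feature array,
# passed for interface consistency but unused.
booster = train_xgboost(wrapper, X)
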
def _training_data_to_numpy(X, y):
"""Convert input training data into numpy format"""
if isinstance(X, np.ndarray):
Expand Down Expand Up @@ -372,3 +414,145 @@ def _build_mnmg_umap(m, data, args, tmpdir):
local_model.fit(X)

return m(client=client, model=local_model, **args)


class XGBoostTrainWrapper:
"""Helper class for benchmarking XGBoost training phase"""

def __init__(
self,
dtrain,
params,
num_boost_round,
device="cpu",
):
self.dtrain = dtrain
self.params = params
self.num_boost_round = num_boost_round
self.device = device
self.booster = None

def retrain(self):
"""Retrain the model - this is what gets timed"""

debug_mode = os.environ.get("XGBOOST_DEBUG", "0") == "1"
if debug_mode:
print(
f"[XGBoost Retrain] Device: {self.device}, Rounds: {self.num_boost_round}"
)

self.booster = xgb.train(
self.params, self.dtrain, self.num_boost_round
)
return self.booster


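A minimal standalone sketch of the wrapper's contract with synthetic data; it assumes only numpy and xgboost and mirrors the setup function that follows:

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((1_000, 16)).astype(np.float32)
y = (rng.random(1_000) > 0.5).astype(np.int32)

# Build everything up front (untimed)...
dtrain = xgb.DMatrix(X, label=y)
params = {
    "objective": "binary:logistic",
    "tree_method": "hist",
    "device": "cpu",
}
wrapper = XGBoostTrainWrapper(dtrain, params, num_boost_round=10)

# ...so that only the training call itself is timed.
booster = wrapper.retrain()
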
def _build_xgboost_for_training(
m, data, args, tmpdir, task_type="classification"
):
"""
Common setup function for XGBoost training: builds the DMatrix and training parameters without training, so the training phase itself can be timed.

Parameters
----------
m : object
Unused, but required for setup-function interface consistency.
data : tuple
Training data as (features, labels).
args : dict
Configuration arguments.
tmpdir : str
Temporary directory (unused).
task_type : str
Either "classification" or "regression".

Returns
-------
XGBoostTrainWrapper
Wrapper holding the prepared DMatrix and parameters, ready for timed training.
"""

train_data, train_label = _training_data_to_numpy(data[0], data[1])

args_copy = args.copy()

use_quantile_dmatrix = args_copy.pop("use_quantile_dmatrix", False)
max_bin = args_copy.pop("max_bin", 256)
num_boost_round = args_copy.pop("n_estimators", 100)
device = args_copy.pop("device", "cpu")

debug_mode = os.environ.get("XGBOOST_DEBUG", "0") == "1"
if debug_mode:
task_name = (
"Classifier" if task_type == "classification" else "Regressor"
)
print(f"[XGBoost Setup {task_name} for Training] Device: {device}")

# Task-specific label processing and parameter setup
if task_type == "classification":
unique_labels = np.unique(train_label)
n_classes = len(unique_labels)

label_map = {
old_label: new_label
for new_label, old_label in enumerate(unique_labels)
}
train_label_normalized = np.array(
[label_map[label] for label in train_label]
)

# Determine objective based on number of classes
if n_classes == 2:
objective = "binary:logistic"
eval_metric = "error"
else:
objective = "multi:softmax"
eval_metric = "merror"

params = {
[Review comment (Member): The max_bin needs to be part of training parameters as well.]
"objective": objective,
"eval_metric": eval_metric,
"device": device,
}

if n_classes > 2:
params["num_class"] = n_classes
else:
train_label_normalized = train_label

params = {
"objective": "reg:squarederror",
"eval_metric": "rmse",
"device": device,
}

if use_quantile_dmatrix:
dtrain = xgb.QuantileDMatrix(
train_data, label=train_label_normalized, max_bin=max_bin
)
else:
dtrain = xgb.DMatrix(train_data, label=train_label_normalized)

params.update(args_copy)

if debug_mode:
task_name = (
"Classifier" if task_type == "classification" else "Regressor"
)
print(
f"[XGBoost Setup {task_name} for Training] Final params: {params}"
)

# Return wrapper that can retrain (training will be timed)
return XGBoostTrainWrapper(
dtrain,
params,
num_boost_round,
device=device,
)


def _build_xgboost_classifier_for_training(m, data, args, tmpdir):
"""Setup function for XGBoost classification - prepares but doesn't train (for timing training)"""
return _build_xgboost_for_training(
m, data, args, tmpdir, task_type="classification"
)


def _build_xgboost_regressor_for_training(m, data, args, tmpdir):
"""Setup function for XGBoost regression - prepares but doesn't train (for timing training)"""
return _build_xgboost_for_training(
m, data, args, tmpdir, task_type="regression"
)
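
End-to-end, the pieces above compose as in the sketch below (synthetic data; in practice the runner assembles args from the AlgorithmPair entries in algorithms.py):

import numpy as np

rng = np.random.default_rng(42)
data = (
    rng.standard_normal((5_000, 20)).astype(np.float32),
    rng.integers(0, 2, 5_000).astype(np.int32),
)
args = {
    "tree_method": "hist",
    "n_estimators": 100,
    "use_quantile_dmatrix": False,
    "max_bin": 256,
    "device": "cpu",
}

# Setup is untimed; it builds the DMatrix and the parameter dict.
wrapper = _build_xgboost_classifier_for_training(None, data, args, tmpdir=None)

# Only this call is timed by the benchmark.
booster = train_xgboost(wrapper, data[0])
print(booster.num_boosted_rounds())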