diff --git a/python/cuml/cuml/benchmark/algorithms.py b/python/cuml/cuml/benchmark/algorithms.py
index 2862a80573..3e6462cb8c 100644
--- a/python/cuml/cuml/benchmark/algorithms.py
+++ b/python/cuml/cuml/benchmark/algorithms.py
@@ -28,6 +28,8 @@
     _build_gtil_classifier,
     _build_mnmg_umap,
     _build_optimized_fil_classifier,
+    _build_xgboost_classifier_for_training,
+    _build_xgboost_regressor_for_training,
     _training_data_to_numpy,
     _treelite_fil_accuracy_score,
     fit,
@@ -35,6 +37,7 @@
     fit_predict,
     fit_transform,
     predict,
+    train_xgboost,
     transform,
 )
 from cuml.preprocessing import (
@@ -60,10 +63,12 @@
 
 
 try:
+    import xgboost as xgb
     from xgboost import XGBClassifier, XGBRegressor
 except ImportError:
     XGBClassifier = None
     XGBRegressor = None
+    xgb = None
 
 
 class AlgorithmPair:
@@ -365,6 +370,82 @@ def all_algorithms():
             accepts_labels=True,
             accuracy_function=metrics.r2_score,
         ),
+        AlgorithmPair(
+            xgb,
+            xgb,
+            shared_args={
+                "tree_method": "hist",
+                "n_estimators": 100,
+                "use_quantile_dmatrix": False,
+                "max_bin": 256,
+            },
+            cpu_args={"device": "cpu"},
+            cuml_args={"device": "cuda"},
+            name="xgboost-native-classification",
+            accepts_labels=False,
+            setup_cpu_func=_build_xgboost_classifier_for_training,
+            setup_cuml_func=_build_xgboost_classifier_for_training,
+            cpu_data_prep_hook=_labels_to_int_hook,
+            cuml_data_prep_hook=_labels_to_int_hook,
+            accuracy_function=metrics.accuracy_score,
+            bench_func=train_xgboost,
+        ),
+        AlgorithmPair(
+            xgb,
+            xgb,
+            shared_args={
+                "tree_method": "hist",
+                "n_estimators": 100,
+                "use_quantile_dmatrix": True,
+                "max_bin": 256,
+            },
+            cpu_args={"device": "cpu"},
+            cuml_args={"device": "cuda"},
+            name="xgboost-native-classification-quantile",
+            accepts_labels=False,
+            setup_cpu_func=_build_xgboost_classifier_for_training,
+            setup_cuml_func=_build_xgboost_classifier_for_training,
+            cpu_data_prep_hook=_labels_to_int_hook,
+            cuml_data_prep_hook=_labels_to_int_hook,
+            accuracy_function=metrics.accuracy_score,
+            bench_func=train_xgboost,
+        ),
+        AlgorithmPair(
+            xgb,
+            xgb,
+            shared_args={
+                "tree_method": "hist",
+                "n_estimators": 100,
+                "use_quantile_dmatrix": False,
+                "max_bin": 256,
+            },
+            cpu_args={"device": "cpu"},
+            cuml_args={"device": "cuda"},
+            name="xgboost-native-regression",
+            accepts_labels=False,
+            setup_cpu_func=_build_xgboost_regressor_for_training,
+            setup_cuml_func=_build_xgboost_regressor_for_training,
+            accuracy_function=metrics.r2_score,
+            bench_func=train_xgboost,
+        ),
+        AlgorithmPair(
+            xgb,
+            xgb,
+            shared_args={
+                "tree_method": "hist",
+                "n_estimators": 100,
+                "use_quantile_dmatrix": True,
+                "max_bin": 256,
+            },
+            cpu_args={"device": "cpu"},
+            cuml_args={"device": "cuda"},
+            name="xgboost-native-regression-quantile",
+            accepts_labels=False,
+            setup_cpu_func=_build_xgboost_regressor_for_training,
+            setup_cuml_func=_build_xgboost_regressor_for_training,
+            accuracy_function=metrics.r2_score,
+            bench_func=train_xgboost,
+        ),
         AlgorithmPair(
             sklearn.manifold.TSNE,
             cuml.manifold.TSNE,
diff --git a/python/cuml/cuml/benchmark/bench_helper_funcs.py b/python/cuml/cuml/benchmark/bench_helper_funcs.py
index 11fa11d462..de85a1e694 100644
--- a/python/cuml/cuml/benchmark/bench_helper_funcs.py
+++ b/python/cuml/cuml/benchmark/bench_helper_funcs.py
@@ -20,6 +20,11 @@
 from cuml.internals.device_type import DeviceType
 from cuml.manifold import UMAP
 
+try:
+    import xgboost as xgb
+except ImportError:
+    xgb = None
+
 
 def call(m, func_name, X, y=None):
     def unwrap_and_get_args(func):
@@ -82,6 +87,43 @@ def fit_kneighbors(m, x, y=None):
     call(m, "fit_kneighbors", x, y)
 
 
+def train_xgboost(m, x, y=None):
+    """
+    Bench function for XGBoost that times the training phase.
+
+    This function works with the XGBoostTrainWrapper instances that the
+    setup functions pre-configure. It calls the wrapper's retrain()
+    method, which performs the actual training that gets timed.
+
+    Parameters
+    ----------
+    m : XGBoostTrainWrapper
+    x : array-like
+    y : array-like, optional
+
+    Returns
+    -------
+    booster : xgboost.Booster
+        The trained XGBoost booster model
+
+    Notes
+    -----
+    The x and y parameters exist only for interface consistency with the
+    other benchmark functions; the training data is already contained in
+    the pre-built DMatrix inside the wrapper, but the benchmarking
+    framework always calls bench_func with (model, data[0], data[1]) or
+    (model, data[0]).
+    """
+    if hasattr(m, "retrain"):
+        return m.retrain()
+    else:
+        raise ValueError(
+            "Expected XGBoostTrainWrapper with 'retrain' method, "
+            f"but got {type(m).__name__}. Ensure the setup function "
+            "(_build_xgboost_*_for_training) was called correctly."
+        )
+
+
 def _training_data_to_numpy(X, y):
     """Convert input training data into numpy format"""
     if isinstance(X, np.ndarray):
@@ -372,3 +414,145 @@ def _build_mnmg_umap(m, data, args, tmpdir):
         local_model.fit(X)
 
     return m(client=client, model=local_model, **args)
+
+
+class XGBoostTrainWrapper:
+    """Helper class that isolates the XGBoost training phase for benchmarking"""
+
+    def __init__(
+        self,
+        dtrain,
+        params,
+        num_boost_round,
+        device="cpu",
+    ):
+        self.dtrain = dtrain
+        self.params = params
+        self.num_boost_round = num_boost_round
+        self.device = device
+        self.booster = None
+
+    def retrain(self):
+        """Train the booster - this is the step that gets timed"""
+
+        debug_mode = os.environ.get("XGBOOST_DEBUG", "0") == "1"
+        if debug_mode:
+            print(
+                f"[XGBoost Retrain] Device: {self.device}, Rounds: {self.num_boost_round}"
+            )
+
+        self.booster = xgb.train(
+            self.params, self.dtrain, self.num_boost_round
+        )
+        return self.booster
+
+
+def _build_xgboost_for_training(
+    m, data, args, tmpdir, task_type="classification"
+):
+    """
+    Shared setup for the XGBoost training benchmarks: builds the DMatrix and params but does not train, so that only the training phase is timed.
+
+    Args:
+        m: Model object (unused; present for interface consistency)
+        data: Training data tuple (features, labels)
+        args: Benchmark configuration arguments
+        tmpdir: Temporary directory (unused)
+        task_type: Either "classification" or "regression"
+
+    Returns:
+        An XGBoostTrainWrapper that is ready to train
+    """
+
+    train_data, train_label = _training_data_to_numpy(data[0], data[1])
+
+    args_copy = args.copy()
+
+    use_quantile_dmatrix = args_copy.pop("use_quantile_dmatrix", False)
+    max_bin = args_copy.pop("max_bin", 256)
+    num_boost_round = args_copy.pop("n_estimators", 100)
+    device = args_copy.pop("device", "cpu")
+
+    debug_mode = os.environ.get("XGBOOST_DEBUG", "0") == "1"
+    if debug_mode:
+        task_name = (
+            "Classifier" if task_type == "classification" else "Regressor"
+        )
+        print(f"[XGBoost Setup {task_name} for Training] Device: {device}")
+
+    # Task-specific label processing and parameter setup
+    if task_type == "classification":
+        unique_labels = np.unique(train_label)
+        n_classes = len(unique_labels)
+
+        label_map = {
+            old_label: new_label
+            for new_label, old_label in enumerate(unique_labels)
+        }
+        train_label_normalized = np.array(
+            [label_map[label] for label in train_label]
+        )
+
+        # Determine objective based on number of classes
+        if n_classes == 2:
+            objective = "binary:logistic"
+            eval_metric = "error"
+        else:
+            objective = "multi:softmax"
+            eval_metric = "merror"
+
+        params = {
+            "objective": objective,
+            "eval_metric": eval_metric,
+            "device": device,
+        }
+
+        if n_classes > 2:
+            params["num_class"] = n_classes
+    else:
+        train_label_normalized = train_label
+
+        params = {
+            "objective": "reg:squarederror",
+            "eval_metric": "rmse",
+            "device": device,
+        }
+
+    if use_quantile_dmatrix:
+        dtrain = xgb.QuantileDMatrix(
+            train_data, label=train_label_normalized, max_bin=max_bin
+        )
+    else:
+        dtrain = xgb.DMatrix(train_data, label=train_label_normalized)
+
+    params.update(args_copy)
+
+    if debug_mode:
+        task_name = (
+            "Classifier" if task_type == "classification" else "Regressor"
+        )
+        print(
+            f"[XGBoost Setup {task_name} for Training] Final params: {params}"
+        )
+
+    # Return wrapper whose retrain() will be timed by the benchmark
+    return XGBoostTrainWrapper(
+        dtrain,
+        params,
+        num_boost_round,
+        device=device,
+    )
+
+
+def _build_xgboost_classifier_for_training(m, data, args, tmpdir):
+    """Setup for XGBoost classification benchmarks; builds inputs without training"""
+    return _build_xgboost_for_training(
+        m, data, args, tmpdir, task_type="classification"
+    )
+
+
+def _build_xgboost_regressor_for_training(m, data, args, tmpdir):
+    """Setup for XGBoost regression benchmarks; builds inputs without training"""
+    return _build_xgboost_for_training(
+        m, data, args, tmpdir, task_type="regression"
+    )
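
How the pieces fit together, outside the harness: the setup function builds the wrapper (DMatrix plus resolved params) once, and train_xgboost then triggers only the step the runner times. A minimal sanity-check sketch; the synthetic data is illustrative, and the hand-merged args dict stands in for the shared_args plus cpu_args that AlgorithmPair combines before calling the setup function:

```python
import time

import numpy as np

from cuml.benchmark.bench_helper_funcs import (
    _build_xgboost_classifier_for_training,
    train_xgboost,
)

# Illustrative synthetic data; any (features, labels) pair works.
X = np.random.rand(10_000, 32).astype(np.float32)
y = np.random.randint(0, 2, size=10_000)

# Hand-merged equivalent of shared_args + cpu_args from the AlgorithmPair.
args = {
    "tree_method": "hist",
    "n_estimators": 100,
    "use_quantile_dmatrix": False,
    "max_bin": 256,
    "device": "cpu",
}

# Setup runs outside the timed region: DMatrix and params are built once.
wrapper = _build_xgboost_classifier_for_training(None, (X, y), args, None)

# Only this call corresponds to what the runner times through bench_func.
start = time.perf_counter()
booster = train_xgboost(wrapper, X, y)
elapsed = time.perf_counter() - start
print(f"trained {booster.num_boosted_rounds()} rounds in {elapsed:.2f}s")
```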
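One detail worth calling out in the classification path: xgb.train expects labels encoded as a contiguous [0, n_classes) range, which is why the setup builds label_map from the sorted output of np.unique. A worked example of that remapping:

```python
import numpy as np

labels = np.array([7, 3, 7, 7, 3])
unique_labels = np.unique(labels)  # array([3, 7]), sorted ascending
label_map = {old: new for new, old in enumerate(unique_labels)}  # {3: 0, 7: 1}
print([label_map[v] for v in labels])  # [1, 0, 1, 1, 0]
```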
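For completeness, invoking the new entries through the benchmark harness would look roughly like the sketch below. algorithm_by_name and SpeedupComparisonRunner exist in cuml.benchmark, but treat the exact constructor arguments shown here as an assumption and check the runner's signature in your checkout:

```python
from cuml.benchmark.algorithms import algorithm_by_name
from cuml.benchmark.runners import SpeedupComparisonRunner

# "classification" is one of the dataset generators in cuml.benchmark.datagen.
runner = SpeedupComparisonRunner(
    bench_rows=[2**16], bench_dims=[64], dataset_name="classification"
)
results = runner.run(algorithm_by_name("xgboost-native-classification"))
```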