From 70c35517490190336120f43b74ab6750a14f1122 Mon Sep 17 00:00:00 2001 From: Tomas Johannesson Date: Sat, 27 Jul 2024 15:02:08 -0500 Subject: [PATCH 1/3] Kernel PCA Python and benchmarking --- BUILD.md | 2 +- notebooks/tools/cuml_benchmarks.ipynb | 24 + python/cuml/CMakeLists.txt | 2 + python/cuml/cuml/benchmark/algorithms.py | 7 + python/cuml/cuml/decomposition/utils.pxd | 23 +- python/cuml/cuml/experimental/__init__.py | 1 + .../experimental/decomposition/CMakeLists.txt | 26 + .../experimental/decomposition/__init__.py | 17 + .../experimental/decomposition/kpca.pyx | 502 ++++++++++++++++++ python/cuml/cuml/tests/test_kpca.py | 296 +++++++++++ python/cuml/cuml/tests/test_pca.py | 9 +- 11 files changed, 904 insertions(+), 5 deletions(-) create mode 100644 python/cuml/cuml/experimental/experimental/decomposition/CMakeLists.txt create mode 100644 python/cuml/cuml/experimental/experimental/decomposition/__init__.py create mode 100644 python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx create mode 100644 python/cuml/cuml/tests/test_kpca.py diff --git a/BUILD.md b/BUILD.md index 4bc8310407..0139f5abf9 100644 --- a/BUILD.md +++ b/BUILD.md @@ -91,7 +91,7 @@ $ pytest --ignore=cuml/tests/dask --ignore=cuml/tests/test_nccl.py If you want a list of the available Python tests: ```bash -$ pytest cuML/tests --collect-only +$ pytest cuml/tests --collect-only ``` ### Manual Process diff --git a/notebooks/tools/cuml_benchmarks.ipynb b/notebooks/tools/cuml_benchmarks.ipynb index 5a7039021a..5f5e30bf5d 100644 --- a/notebooks/tools/cuml_benchmarks.ipynb +++ b/notebooks/tools/cuml_benchmarks.ipynb @@ -531,6 +531,30 @@ "execute_benchmark(\"PCA\", runner)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### KernelPCA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "runner = cuml.benchmark.runners.SpeedupComparisonRunner(\n", + " bench_rows=[400, 800, 1600, 3200, 6400, 12800],\n", + " 
bench_dims=WIDE_FEATURES,\n", + " dataset_name=DATA_CLASSIFICATION,\n", + " input_type=INPUT_TYPE,\n", + " n_reps=N_REPS\n", + ")\n", + "\n", + "execute_benchmark(\"KernelPCA\", runner)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/cuml/CMakeLists.txt b/python/cuml/CMakeLists.txt index f2541f7f04..e1d00c5585 100644 --- a/python/cuml/CMakeLists.txt +++ b/python/cuml/CMakeLists.txt @@ -109,6 +109,7 @@ if(CUML_CPU) set(CUML_ALGORITHMS "linearregression") list(APPEND CUML_ALGORITHMS "pca") + list(APPEND CUML_ALGORITHMS "kpca") list(APPEND CUML_ALGORITHMS "tsvd") list(APPEND CUML_ALGORITHMS "elasticnet") list(APPEND CUML_ALGORITHMS "logisticregression") @@ -179,6 +180,7 @@ add_subdirectory(cuml/svm) add_subdirectory(cuml/tsa) add_subdirectory(cuml/experimental/linear_model) +add_subdirectory(cuml/experimental/decomposition) if(DEFINED cython_lib_dir) rapids_cython_add_rpath_entries(TARGET cuml PATHS "${cython_lib_dir}") diff --git a/python/cuml/cuml/benchmark/algorithms.py b/python/cuml/cuml/benchmark/algorithms.py index be9b5ab841..f1e9a25b62 100644 --- a/python/cuml/cuml/benchmark/algorithms.py +++ b/python/cuml/cuml/benchmark/algorithms.py @@ -245,6 +245,13 @@ def all_algorithms(): name="PCA", accepts_labels=False, ), + AlgorithmPair( + sklearn.decomposition.KernelPCA, + cuml.experimental.KernelPCA, + shared_args=dict(), + name="KernelPCA", + accepts_labels=False, + ), AlgorithmPair( sklearn.decomposition.TruncatedSVD, cuml.decomposition.tsvd.TruncatedSVD, diff --git a/python/cuml/cuml/decomposition/utils.pxd b/python/cuml/cuml/decomposition/utils.pxd index 98134b9437..33896f5222 100644 --- a/python/cuml/cuml/decomposition/utils.pxd +++ b/python/cuml/cuml/decomposition/utils.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,6 +18,19 @@ from libcpp cimport bool ctypedef int underlying_type_t_solver +cdef extern from "raft/distance/distance_types.hpp" namespace "raft::distance::kernels" nogil: + enum KernelType: + LINEAR, + POLYNOMIAL, + RBF, + TANH + + cdef struct KernelParams: + KernelType kernel + int degree + double gamma + double coef0 + cdef extern from "cuml/decomposition/params.hpp" namespace "ML" nogil: ctypedef enum solver "ML::solver": @@ -41,3 +54,11 @@ cdef extern from "cuml/decomposition/params.hpp" namespace "ML" nogil: cdef cppclass paramsPCA(paramsTSVD): bool copy bool whiten + + cdef cppclass paramsKPCA(paramsTSVD): + KernelParams kernel + size_t n_training_samples + bool copy + bool remove_zero_eig + bool fit_inverse_transform + diff --git a/python/cuml/cuml/experimental/__init__.py b/python/cuml/cuml/experimental/__init__.py index 6bafd70ac7..7c9098405c 100644 --- a/python/cuml/cuml/experimental/__init__.py +++ b/python/cuml/cuml/experimental/__init__.py @@ -1 +1,2 @@ from cuml.experimental.fil import ForestInference +from cuml.experimental.decomposition import KernelPCA diff --git a/python/cuml/cuml/experimental/experimental/decomposition/CMakeLists.txt b/python/cuml/cuml/experimental/experimental/decomposition/CMakeLists.txt new file mode 100644 index 0000000000..fbdd9ee7aa --- /dev/null +++ b/python/cuml/cuml/experimental/experimental/decomposition/CMakeLists.txt @@ -0,0 +1,26 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + + + +set(cython_sources "") +add_module_gpu_default("kpca.pyx" ${kpca_algo} ${decomposition_algo}) + +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${cuml_sg_libraries}" + MODULE_PREFIX experimental_ + ASSOCIATED_TARGETS cuml +) diff --git a/python/cuml/cuml/experimental/experimental/decomposition/__init__.py b/python/cuml/cuml/experimental/experimental/decomposition/__init__.py new file mode 100644 index 0000000000..5d9a64d3c4 --- /dev/null +++ b/python/cuml/cuml/experimental/experimental/decomposition/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from cuml.experimental.decomposition.kpca import KernelPCA diff --git a/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx b/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx new file mode 100644 index 0000000000..9bbb12c5c2 --- /dev/null +++ b/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx @@ -0,0 +1,502 @@ +# distutils: language = c++ + +from cuml.internals.safe_imports import cpu_only_import +np = cpu_only_import('numpy') +from cuml.internals.safe_imports import gpu_only_import +cp = gpu_only_import('cupy') +cupyx = gpu_only_import('cupyx') +scipy = cpu_only_import('scipy') + +rmm = gpu_only_import('rmm') + +from libc.stdint cimport uintptr_t + +import cuml.internals +from cuml.internals.array import CumlArray +from cuml.internals.base import UniversalBase +from cuml.common.doc_utils import generate_docstring +from cuml.internals.input_utils import input_to_cuml_array +from cuml.internals.input_utils import input_to_cupy_array +from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common import using_output_type +from cuml.prims.stats import cov +from cuml.internals.input_utils import sparse_scipy_to_cp +from cuml.common.exceptions import NotFittedError +from cuml.internals.mixins import FMajorInputTagMixin +from cuml.internals.api_decorators import device_interop_preparation +from cuml.internals.api_decorators import enable_device_interop +from cuml.internals import logger + +IF GPUBUILD == 1: + from enum import IntEnum + from cython.operator cimport dereference as deref + from cuml.decomposition.utils cimport * + from pylibraft.common.handle cimport handle_t + + cdef extern from "cuml/decomposition/kpca.hpp" namespace "ML": + cdef void kpcaFit(handle_t& handle, + float *input, + float *eigenvectors, + float *eigenvalues, + int *n_components, + const paramsKPCA &prms) except + + + cdef void kpcaFit(handle_t& handle, + double *input, + double *eigenvectors, + double *eigenvalues, + int 
*n_components, + const paramsKPCA &prms) except + + + + cdef void kpcaTransformWithFitData(handle_t& handle, + float *eigenvectors, + float *eigenvalues, + float *trans_input, + const paramsKPCA &prms) except + + + cdef void kpcaTransformWithFitData(handle_t& handle, + double *eigenvectors, + double *eigenvalues, + double *trans_input, + const paramsKPCA &prms) except + + + cdef void kpcaTransform(handle_t& handle, + float *fit_input, + float *input, + float *eigenvectors, + float *eigenvalues, + float *trans_input, + const paramsKPCA &prms) except + + + cdef void kpcaTransform(handle_t& handle, + double *fit_input, + double *input, + double *eigenvectors, + double *eigenvalues, + double *trans_input, + const paramsKPCA &prms) except + + + class Solver(IntEnum): + COV_EIG_DQ = solver.COV_EIG_DQ + COV_EIG_JACOBI = solver.COV_EIG_JACOBI + + +class KernelPCA(UniversalBase, + FMajorInputTagMixin): + + """ + KernelPCA (Kernel Principal Component Analysis) is an extension of PCA + that allows for non-linear dimensionality reduction through the use of + kernel methods. It projects the data into a higher-dimensional space + where it becomes linearly separable, and then applies PCA to capture the + most variance in the data. + + cuML's KernelPCA expects an array-like object or cuDF DataFrame, and + supports various kernels such as linear, polynomial, RBF, and sigmoid. + + Examples + -------- + .. 
code-block:: python + + >>> # Importing KernelPCA + >>> from cuml.experimental.decomposition import KernelPCA + + >>> import cudf + >>> import cupy as cp + + >>> gdf_float = cudf.DataFrame() + >>> gdf_float['0'] = cp.asarray([1.0, 2.0, 5.0], dtype=cp.float32) + >>> gdf_float['1'] = cp.asarray([4.0, 2.0, 1.0], dtype=cp.float32) + >>> gdf_float['2'] = cp.asarray([4.0, 2.0, 1.0], dtype=cp.float32) + + >>> kpca_float = KernelPCA(n_components=2, kernel='rbf', gamma=15) + >>> kpca_float.fit(gdf_float) + KernelPCA() + + >>> print(f'components: {kpca_float.eigenvalues_}') # doctest: +SKIP + components: + 0 1.0 + 1 1.0 + >>> print(f'eigen vectors: {kpca_float.eigenvectors_}') # doctest: +SKIP + eigen vectors: [...] + 0 1 + 0 -0.408248 0.707107 + 1 -0.408248 -0.707107 + 2 0.816497 0.000000 + >>> trans_gdf_float = kpca_float.transform(gdf_float) + >>> print(f'Transformed: {trans_gdf_float}') # doctest: +SKIP + Transformed: + 0 1 + 0 -0.408248 7.071068e-01 + 1 -0.408248 -7.071068e-01 + 2 0.816497 -1.284374e-08 + + Parameters + ---------- + n_components : int, optional (default=None) + The number of components to keep. If None, all non zero eigenvalues are kept. + kernel : {'linear', 'poly', 'rbf', 'sigmoid'}, optional (default='linear') + Kernel to be used in the algorithm. + gamma : float, optional (default=None) + Kernel coefficient for 'rbf', 'poly', and 'sigmoid'. If None, 1/n_features is used. + degree : int, optional (default=3) + Degree for the polynomial kernel. Ignored by other kernels. + coef0 : float, optional (default=1) + Independent term in kernel function. It is only significant in 'poly' and 'sigmoid'. + kernel_params : dict, optional (default=None) + Parameters (keyword arguments) and values for kernel passed as callable object. + alpha : float, optional (default=1.0) + Hyperparameter of the ridge regression that learns the inverse transform. Inverse transform not supported in cuML. 
+ fit_inverse_transform : bool, optional (default=False) + Not supported in cuML. + eigen_solver : {'auto', 'full', 'jacobi'}, optional (default='auto') + Select eigensolver to use. + tol : float, optional (default=0) + Convergence tolerance passed to the eigensolver ('arpack' is not available in cuML). + max_iter : int, optional (default=None) + Not supported in available eigen solvers. + remove_zero_eig : bool, optional (default=False) + If True, then all components with zero eigenvalues are removed. + random_state : int or None, optional (default=None) + Seed for the random number generator. Not supported in available eigen solvers. + copy_X : bool, optional (default=True) + If True, input X is copied and stored. Otherwise, X may be overwritten. + verbose : int or bool, optional (default=False) + Enable verbose output. If True, output is printed. If False, no output. + output_type : {'input', 'array', 'dataframe', 'series', 'df_obj', 'numba', 'cupy', 'numpy', 'cudf', 'pandas'}, optional (default=None) + Return results and set estimator attributes to the indicated output type. If None, the output type set at the module level (`cuml.global_settings.output_type`) will be used. + + Attributes + ---------- + eigenvectors_ : array + Eigenvectors in the transformed space. + eigenvalues_ : array + Eigenvalues in the transformed space. + X_fit_ : array + Data used for fitting. + gamma_ : float + Kernel coefficient for 'rbf', 'poly', and 'sigmoid'. + n_features_in_ : int + Number of features in the input data. + n_samples_ : int + Number of samples in the input data. + feature_names_in_ : list + Names of the features in the input data. + n_components_ : int + Number of components to keep. If None, all non zero eigenvalues are kept. + + Notes + ----- + KernelPCA (KPCA) is a non-linear extension of PCA, which allows for the capture + of complex, non-linear structures in the data. This makes KPCA suitable for datasets + where linear assumptions are insufficient to capture the underlying patterns. 
+ It employs kernel methods to project data into a higher-dimensional space where + it becomes linearly separable, thus retaining more meaningful structure. + + **Applications of KernelPCA** + + KernelPCA is widely used for feature extraction and dimensionality reduction + in various domains. It is particularly effective for data that exhibits + non-linear relationships, such as in image denoising, pattern recognition, + and pre-processing data for machine learning algorithms. It has been applied + to gene expression data to uncover complex biological patterns, and in + image processing to improve the performance of object recognition systems. + + + For additional docs, see `scikit-learn's KernelPCA + `_. + """ + + _cpu_estimator_import_path = 'sklearn.decomposition.KernelPCA' + eigenvalues_ = CumlArrayDescriptor(order='F') + eigenvectors_ = CumlArrayDescriptor(order='F') + trans_input_ = CumlArrayDescriptor(order='F') + + @device_interop_preparation + def __init__(self, *, handle=None, n_components=None, kernel='linear', gamma=None, + degree=3, coef0=1, kernel_params=None, alpha=1.0, + fit_inverse_transform=False, eigen_solver='auto', tol=0, + max_iter=None, iterated_power=15, remove_zero_eig=False, n_jobs=None, + random_state=None, copy_X=True, verbose=False, output_type=None): + if fit_inverse_transform: + raise NotImplementedError("Inverse transform is not supported") + if random_state is not None: + raise NotImplementedError("Random state is not supported in available eigen solvers") + if n_jobs is not None and n_jobs != -1: + raise NotImplementedError("n_jobs does not apply to this algorithm") + if max_iter is not None: + raise NotImplementedError("max_iter is not supported in available eigen solvers. 
Use iterated_power for Jacobi solver") + super().__init__(handle=handle, + verbose=verbose, + output_type=output_type) + self.copy_X = copy_X + self.max_iter = max_iter + self.iterated_power = iterated_power + self.n_components_ = n_components + self.remove_zero_eig = remove_zero_eig + self.random_state = random_state + self.eigen_solver = eigen_solver + self.tol = tol + self.kernel = kernel + self.c_kernel = self._get_c_kernel(kernel) + self.c_algorithm = self._get_algorithm_c_name(self.eigen_solver) + self.gamma_ = gamma + self.degree = degree + self.coef0 = coef0 + self.alpha = alpha + self.fit_inverse_transform = fit_inverse_transform + + self.trans_input_ = None + self.eigenvectors_ = None + self.eigenvalues_ = None + + def _get_c_kernel(self, kernel): + """ + Get KernelType from the kernel string. + + Parameters + ---------- + kernel: string, ('linear', 'poly', 'rbf', or 'sigmoid') + """ + return { + 'linear': LINEAR, + 'poly': POLYNOMIAL, + 'rbf': RBF, + 'sigmoid': TANH + }[kernel] + + def _get_algorithm_c_name(self, algorithm): + IF GPUBUILD == 1: + algo_map = { + 'full': Solver.COV_EIG_DQ, + 'auto': Solver.COV_EIG_DQ, + 'jacobi': Solver.COV_EIG_JACOBI + } + if algorithm not in algo_map: + msg = "algorithm {!r} is not supported" + raise TypeError(msg.format(algorithm)) + + return algo_map[algorithm] + + def _build_params(self, n_rows, n_cols): + IF GPUBUILD == 1: + cdef paramsKPCA *params = new paramsKPCA() + params.n_components = min(self.n_components_ or n_rows, n_rows) + params.n_training_samples = n_rows + params.n_rows = n_rows + params.n_cols = n_cols + params.n_iterations = self.iterated_power + params.tol = self.tol + params.verbose = self.verbose + params.remove_zero_eig = self.remove_zero_eig or self.n_components_ is None + params.algorithm = ( ( + self.c_algorithm)) + params.fit_inverse_transform = self.fit_inverse_transform + params.kernel = self._get_kernel_params(n_cols) + return params + + def _initialize_arrays(self, n_rows, n_cols): + # 
Will be resized to (n_components) after fit + self.eigenvalues_ = CumlArray.zeros((n_rows), + dtype=self.dtype) + # Will be resized to (n_rows, n_components) after fit + self.eigenvectors_ = CumlArray.zeros((n_rows, n_rows), + dtype=self.dtype) + + + @generate_docstring(X='dense') + @enable_device_interop + def fit(self, X, y=None) -> "KernelPCA": + """ + Fit the model with X. y is currently ignored. + + """ + if self.copy_X: + self.X_fit_ = X.copy() + else: + self.X_fit_ = X + self.X_m, self.n_samples_, self.n_features_in_, self.dtype = \ + input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) + if self.n_samples_ < 2: raise ValueError('n_samples must be greater than 1') + if self.n_features_in_ < 1: raise ValueError('n_features_in_ must be greater than 0') + cdef uintptr_t _input_ptr = self.X_m.ptr + self.feature_names_in_ = self.X_m.index + IF GPUBUILD == 1: + cdef paramsKPCA *params = \ + self._build_params(self.n_samples_, self.n_features_in_) + + + # Calling _initialize_arrays, guarantees everything is CumlArray + self._initialize_arrays(params.n_rows, params.n_cols) + + cdef uintptr_t eigenvectors_ptr = self.eigenvectors_.ptr + + cdef uintptr_t eigenvalues_ptr = \ + self.eigenvalues_.ptr + cdef int components = (self.n_components_ or -1) + cdef int* component_ptr = &components + cdef handle_t* handle_ = self.handle.getHandle() + if self.dtype == np.float32: + kpcaFit(handle_[0], + _input_ptr, + eigenvectors_ptr, + eigenvalues_ptr, + component_ptr, + deref(params)) + else: + kpcaFit(handle_[0], + _input_ptr, + eigenvectors_ptr, + eigenvalues_ptr, + component_ptr, + deref(params)) + # make sure the previously scheduled gpu tasks are complete before the + # following transfers start + self.handle.sync() + self.n_components_ = components + self.eigenvalues_ = self.eigenvalues_[:components] + self.eigenvectors_ = self.eigenvectors_[:, :components] + return self + + @generate_docstring(X='dense', + return_values={'name': 'trans', + 'type': 'dense', + 
'description': 'Transformed values', + 'shape': '(n_samples, n_components)'}) + @cuml.internals.api_base_return_array_skipall + @enable_device_interop + def fit_transform(self, X, y=None) -> CumlArray: + """ + Apply dimensionality reduction to X. + + X is projected on the first principal components previously extracted + from a training set. + + """ + self.fit(X) + IF GPUBUILD == 1: + cdef paramsKPCA *params = \ + self._build_params(self.n_samples_, self.n_features_in_) + + cdef uintptr_t eigenvectors_ptr = self.eigenvectors_.ptr + + cdef uintptr_t eigenvalues_ptr = \ + self.eigenvalues_.ptr + + t_input_data = \ + CumlArray.zeros((params.n_rows, params.n_components), + dtype=self.dtype.type, index=self.X_m.index) + cdef uintptr_t _trans_input_ptr = t_input_data.ptr + + cdef handle_t* handle_ = self.handle.getHandle() + if self.dtype.type == np.float32: + kpcaTransformWithFitData(handle_[0], + eigenvectors_ptr, + eigenvalues_ptr, + _trans_input_ptr, + deref(params)) + else: + kpcaTransformWithFitData(handle_[0], + eigenvectors_ptr, + eigenvalues_ptr, + _trans_input_ptr, + deref(params)) + # make sure the previously scheduled gpu tasks are complete before the + # following transfers start + self.handle.sync() + return t_input_data + + @enable_device_interop + def transform(self, X, convert_dtype=False) -> CumlArray: + """ + Apply dimensionality reduction to X. + + X is projected on the first principal components previously extracted + from a training set. 
+ + """ + self._check_is_fitted('eigenvectors_') + cdef uintptr_t _fit_input_ptr = self.X_m.ptr + + dtype = self.eigenvectors_.dtype + + X_m, _n_rows, _n_cols, dtype = \ + input_to_cuml_array(X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype + else None), + check_cols=self.n_features_in_) + if _n_cols != self.n_features_in_: + raise ValueError("Number of columns in input must match the " + "number of columns in the training data") + if _n_rows < 1: + raise ValueError("Number of rows in input must be greater than 0") + cdef uintptr_t _input_ptr = X_m.ptr + + IF GPUBUILD == 1: + cdef paramsKPCA params + params.n_training_samples = self.n_samples_ + params.n_components = len(self.eigenvalues_) + params.n_rows = _n_rows + params.n_cols = _n_cols + params.kernel = self._get_kernel_params(_n_cols) + t_input_data = \ + CumlArray.zeros((params.n_rows, params.n_components), + dtype=dtype.type, index=X_m.index) + + cdef uintptr_t _trans_input_ptr = t_input_data.ptr + cdef uintptr_t eigenvalues_ptr = self.eigenvalues_.ptr + cdef uintptr_t eigenvectors_ptr = \ + self.eigenvectors_.ptr + + cdef handle_t* handle_ = self.handle.getHandle() + if dtype.type == np.float32: + kpcaTransform(handle_[0], + _fit_input_ptr, + _input_ptr, + eigenvectors_ptr, + eigenvalues_ptr, + _trans_input_ptr, + params) + else: + kpcaTransform(handle_[0], + _fit_input_ptr, + _input_ptr, + eigenvectors_ptr, + eigenvalues_ptr, + _trans_input_ptr, + params) + + # make sure the previously scheduled gpu tasks are complete before the + # following transfers start + self.handle.sync() + + return t_input_data + + def _get_kernel_params(self, n_cols): + """ Wrap the kernel parameters in a KernelParams object """ + cdef KernelParams _kernel_params + if not self.gamma_: + self.gamma_ = 1 / n_cols + _kernel_params.kernel = self.c_kernel + _kernel_params.degree = self.degree + _kernel_params.gamma = self.gamma_ + _kernel_params.coef0 = self.coef0 + return _kernel_params + + def get_param_names(self): 
+ return super().get_param_names() + \ + ["copy_X", "iterated_power", "n_components", "eigen_solver", "tol", + "random_state", "kernel", "gamma", "degree", "coef0", "alpha", + "fit_inverse_transform", "remove_zero_eig", "kernel_params", "max_iter"] + + def _check_is_fitted(self, attr): + if not hasattr(self, attr) or (getattr(self, attr) is None): + msg = ("This instance is not fitted yet. Call 'fit' " + "with appropriate arguments before using this estimator.") + raise NotFittedError(msg) + + + def get_attr_names(self): + return ['eigenvectors_', 'eigenvalues_', 'n_components_', 'X_fit_', + 'n_samples_', 'n_features_in_', 'feature_names_in_', 'gamma_'] diff --git a/python/cuml/cuml/tests/test_kpca.py b/python/cuml/cuml/tests/test_kpca.py new file mode 100644 index 0000000000..fa052c2cec --- /dev/null +++ b/python/cuml/cuml/tests/test_kpca.py @@ -0,0 +1,296 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from cuml.common.exceptions import NotFittedError +from sklearn.datasets import make_blobs +from sklearn.decomposition import KernelPCA as skKernelPCA +from sklearn.datasets import make_multilabel_classification +from sklearn import datasets +from sklearn.model_selection import train_test_split + +from cuml.testing.utils import ( + get_handle, + array_equal, + unit_param, + quality_param, + stress_param, +) +from cuml.experimental.decomposition import KernelPCA as cuKernelPCA +import pytest +from cuml.internals.safe_imports import gpu_only_import +from cuml.internals.safe_imports import cpu_only_import +from cuml.internals import logger + + +np = cpu_only_import("numpy") +cp = gpu_only_import("cupy") +cupyx = gpu_only_import("cupyx") + + +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) +@pytest.mark.parametrize("input_type", ["ndarray"]) +@pytest.mark.parametrize("use_handle", [True, False]) +@pytest.mark.parametrize( + "name", [unit_param(None), quality_param("digits"), stress_param("blobs")] +) +@pytest.mark.parametrize("kernel", ["linear", "poly", "rbf", "sigmoid"]) +def test_kpca_fit(datatype, input_type, name, use_handle, kernel): + if name == "blobs": + pytest.skip("fails when using blobs dataset") + X, y = make_blobs(n_samples=25000, n_features=1000, random_state=0) + + elif name == "digits": + X, _ = datasets.load_digits(return_X_y=True) + + else: + X, Y = make_multilabel_classification( + n_samples=500, + n_classes=2, + n_labels=1, + allow_unlabeled=False, + random_state=1, + ) + X = X.astype(datatype) + skpca = skKernelPCA(kernel=kernel, eigen_solver="dense", n_components=2) + skpca.fit(X) + handle, stream = get_handle(use_handle) + cupca = cuKernelPCA(handle=handle, kernel=kernel, n_components=2) + cupca.fit(X) + cupca.handle.sync() + eigvals = getattr(cupca, "eigenvalues_") + for attr in [ + "eigenvectors_", + "eigenvalues_", + ]: + cuml_res = getattr(cupca, attr) + + skl_res = getattr(skpca, attr) + assert array_equal( + cuml_res, 
skl_res, 1e-1, total_tol=1e-1, with_sign=True + ) + + +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) +@pytest.mark.parametrize("input_type", ["ndarray"]) +@pytest.mark.parametrize("use_handle", [True, False]) +@pytest.mark.parametrize( + "name", [unit_param(None), quality_param("iris"), stress_param("blobs")] +) +@pytest.mark.parametrize("kernel", ["linear", "poly", "rbf", "sigmoid"]) +def test_kpca_fit_then_transform( + datatype, input_type, name, use_handle, kernel +): + blobs_n_samples = 25000 + if name == "blobs" and pytest.max_gpu_memory < 32: + if pytest.adapt_stress_test: + blobs_n_samples = int(blobs_n_samples * pytest.max_gpu_memory / 32) + else: + pytest.skip( + "Insufficient GPU memory for this test." + "Re-run with 'CUML_ADAPT_STRESS_TESTS=True'" + ) + + if name == "blobs": + X, y = make_blobs( + n_samples=blobs_n_samples, n_features=1000, random_state=0 + ) + + elif name == "iris": + iris = datasets.load_iris() + X = iris.data + + else: + X, Y = make_multilabel_classification( + n_samples=500, + n_classes=2, + n_labels=1, + allow_unlabeled=False, + random_state=1, + ) + + X = X.astype(datatype) + if name != "blobs": + skpca = skKernelPCA(n_components=2, kernel=kernel) + skpca.fit(X) + X_sk = skpca.transform(X) + + handle, stream = get_handle(use_handle) + cupca = cuKernelPCA(n_components=2, handle=handle, kernel=kernel) + + cupca.fit(X) + X_cu = cupca.transform(X) + cupca.handle.sync() + + if name != "blobs": + assert array_equal(X_cu, X_sk, 1e-1, total_tol=1e-1, with_sign=True) + assert X_sk.shape[0] == X_cu.shape[0] + assert X_sk.shape[1] == X_cu.shape[1] + + +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) +@pytest.mark.parametrize("input_type", ["ndarray"]) +@pytest.mark.parametrize("use_handle", [True, False]) +@pytest.mark.parametrize( + "name", [unit_param(None), quality_param("iris"), stress_param("blobs")] +) +@pytest.mark.parametrize("kernel", ["linear", "poly", "rbf", "sigmoid"]) +def 
test_kpca_fit_transform(datatype, input_type, name, use_handle, kernel): + blobs_n_samples = 25000 + if name == "blobs" and pytest.max_gpu_memory < 32: + if pytest.adapt_stress_test: + blobs_n_samples = int(blobs_n_samples * pytest.max_gpu_memory / 32) + else: + pytest.skip( + "Insufficient GPU memory for this test." + "Re-run with 'CUML_ADAPT_STRESS_TESTS=True'" + ) + + if name == "blobs": + X, y = make_blobs( + n_samples=blobs_n_samples, n_features=1000, random_state=0 + ) + + elif name == "iris": + iris = datasets.load_iris() + X = iris.data + + else: + X, Y = make_multilabel_classification( + n_samples=500, + n_classes=2, + n_labels=1, + allow_unlabeled=False, + random_state=1, + ) + + X = X.astype(datatype) + if name != "blobs": + skpca = skKernelPCA(n_components=2, kernel=kernel) + X_sk = skpca.fit_transform(X) + + handle, stream = get_handle(use_handle) + cupca = cuKernelPCA(n_components=2, handle=handle, kernel=kernel) + + X_cu = cupca.fit_transform(X) + cupca.handle.sync() + + if name != "blobs": + assert array_equal(X_cu, X_sk, 1e-1, total_tol=1e-1, with_sign=True) + assert X_sk.shape[0] == X_cu.shape[0] + assert X_sk.shape[1] == X_cu.shape[1] + + +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) +@pytest.mark.parametrize("input_type", ["ndarray"]) +@pytest.mark.parametrize("use_handle", [True, False]) +@pytest.mark.parametrize( + "name", [unit_param(None)] +) # [unit_param(None), quality_param("iris")]) +@pytest.mark.parametrize("kernel", ["linear", "poly", "rbf", "sigmoid"]) +def test_kpca_fit_then_transform_on_test_train_split( + datatype, input_type, name, use_handle, kernel +): + if name == "iris": + iris = datasets.load_iris() + X = iris.data + else: + X, _ = make_multilabel_classification( + n_samples=500, + n_classes=2, + n_labels=1, + allow_unlabeled=False, + random_state=1, + ) + + X = X.astype(datatype) + X_train, X_test = train_test_split(X, random_state=0) + skpca = skKernelPCA(n_components=2, kernel=kernel) + skpca.fit(X_train) 
+ X_test_sk = skpca.transform(X_test) + handle, stream = get_handle(use_handle) + cupca = cuKernelPCA(n_components=2, handle=handle, kernel=kernel) + cupca.fit(X_train) + X_test_cu = cupca.transform(X_test) + cupca.handle.sync() + assert array_equal( + X_test_cu, X_test_sk, 1e-1, total_tol=1e-1, with_sign=True + ) + assert X_test_sk.shape[0] == X_test_cu.shape[0] + assert X_test_sk.shape[1] == X_test_cu.shape[1] + + +@pytest.mark.parametrize("n_samples", [200]) +@pytest.mark.parametrize("n_features", [100, 300]) +def test_kpca_defaults(n_samples, n_features): + X, _ = make_multilabel_classification( + n_samples=n_samples, + n_features=n_features, + n_classes=2, + n_labels=1, + random_state=1, + ) + cupca = cuKernelPCA() + cupca.fit(X) + curesult = cupca.transform(X) + cupca.handle.sync() + + skpca = skKernelPCA() + skpca.fit(X) + skresult = skpca.transform(X) + + assert skpca.eigen_solver == cupca.eigen_solver + assert cupca.eigenvalues_.shape[0] == skpca.eigenvalues_.shape[0] + assert cupca.eigenvectors_.shape[0] == skpca.eigenvectors_.shape[0] + assert curesult.shape == skresult.shape + assert array_equal(curesult, skresult, 1e-3, with_sign=False) + + +@pytest.mark.parametrize("n_samples", [200]) +@pytest.mark.parametrize("n_features", [100, 300]) +def test_kpca_fit_transform_defaults(n_samples, n_features): + X, _ = make_multilabel_classification( + n_samples=n_samples, + n_features=n_features, + n_classes=2, + n_labels=1, + random_state=1, + ) + cupca = cuKernelPCA() + curesult = cupca.fit_transform(X) + cupca.handle.sync() + + skpca = skKernelPCA() + skresult = skpca.fit_transform(X) + + assert skpca.eigen_solver == cupca.eigen_solver + assert cupca.eigenvalues_.shape[0] == skpca.eigenvalues_.shape[0] + assert cupca.eigenvectors_.shape[0] == skpca.eigenvectors_.shape[0] + assert curesult.shape == skresult.shape + assert array_equal(curesult, skresult, 1e-3, with_sign=False) + + +def test_exceptions(): + # KernelPCA is not fitted + with 
pytest.raises(NotFittedError): + X = cp.random.random((10, 10)) + cuKernelPCA().transform(X) + + # Eigensolver arpack is supported in sklearn, but not in cuML + with pytest.raises(TypeError): + cuKernelPCA(eigen_solver="arpack") + + # fit_inverse_transform is not supported in cuML + with pytest.raises(NotImplementedError): + cuKernelPCA(fit_inverse_transform=True) diff --git a/python/cuml/cuml/tests/test_pca.py b/python/cuml/cuml/tests/test_pca.py index 10db9a4f7b..8e7b09d9a1 100644 --- a/python/cuml/cuml/tests/test_pca.py +++ b/python/cuml/cuml/tests/test_pca.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -58,6 +58,7 @@ def test_pca_fit(datatype, input_type, name, use_handle): allow_unlabeled=False, random_state=1, ) + X = X.astype(datatype) skpca = skPCA(n_components=2) skpca.fit(X) @@ -158,7 +159,7 @@ def test_pca_fit_then_transform(datatype, input_type, name, use_handle): allow_unlabeled=False, random_state=1, ) - + X = X.astype(datatype) if name != "blobs": skpca = skPCA(n_components=2) skpca.fit(X) @@ -213,6 +214,8 @@ def test_pca_fit_transform(datatype, input_type, name, use_handle): random_state=1, ) + X = X.astype(datatype) + if name != "blobs": skpca = skPCA(n_components=2) Xskpca = skpca.fit_transform(X) @@ -247,7 +250,7 @@ def test_pca_inverse_transform(datatype, input_type, name, use_handle, nrows): X = rng.randn(n, p) # spherical data X[:, 1] *= 0.00001 # make middle component relatively small X += [3, 4, 2] # make a large mean - + X = X.astype(datatype) handle, stream = get_handle(use_handle) cupca = cuPCA(n_components=2, handle=handle) From 15894db05955e7664e25db888fb4dade831fd001 Mon Sep 17 00:00:00 2001 From: Tomas Johannesson Date: Sat, 27 Jul 2024 18:56:38 -0500 Subject: [PATCH 2/3] kpca function comment update ---
.../experimental/decomposition/kpca.pyx | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx b/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx index 9bbb12c5c2..948df3ec10 100644 --- a/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx +++ b/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx @@ -309,8 +309,8 @@ class KernelPCA(UniversalBase, @enable_device_interop def fit(self, X, y=None) -> "KernelPCA": """ - Fit the model with X. y is currently ignored. - + Fit the model with X. + Param y is not used, present for API consistency by convention. """ if self.copy_X: self.X_fit_ = X.copy() @@ -368,11 +368,8 @@ class KernelPCA(UniversalBase, @enable_device_interop def fit_transform(self, X, y=None) -> CumlArray: """ - Apply dimensionality reduction to X. - - X is projected on the first principal components previously extracted - from a training set. - + Fit the model with X and apply the kernel-based dimensionality reduction on X. + Param y is not used, present for API consistency by convention. """ self.fit(X) IF GPUBUILD == 1: @@ -408,13 +405,13 @@ class KernelPCA(UniversalBase, return t_input_data @enable_device_interop - def transform(self, X, convert_dtype=False) -> CumlArray: """ - Apply dimensionality reduction to X. + Apply kernel-based dimensionality reduction to X. - X is projected on the first principal components previously extracted - from a training set. + X is projected into the kernel principal component space + learned from the training set. + Param y is not used, present for API consistency by convention. 
""" self._check_is_fitted('eigenvectors_') cdef uintptr_t _fit_input_ptr = self.X_m.ptr From 52eeeef9ed601092895cf782bd38aa99a4f78da6 Mon Sep 17 00:00:00 2001 From: Tomas Johannesson Date: Sun, 11 Aug 2024 13:55:33 -0500 Subject: [PATCH 3/3] fix incorrect path from merge and add transform() function back --- .../experimental/{experimental => }/decomposition/CMakeLists.txt | 0 .../experimental/{experimental => }/decomposition/__init__.py | 0 .../cuml/experimental/{experimental => }/decomposition/kpca.pyx | 1 + 3 files changed, 1 insertion(+) rename python/cuml/cuml/experimental/{experimental => }/decomposition/CMakeLists.txt (100%) rename python/cuml/cuml/experimental/{experimental => }/decomposition/__init__.py (100%) rename python/cuml/cuml/experimental/{experimental => }/decomposition/kpca.pyx (99%) diff --git a/python/cuml/cuml/experimental/experimental/decomposition/CMakeLists.txt b/python/cuml/cuml/experimental/decomposition/CMakeLists.txt similarity index 100% rename from python/cuml/cuml/experimental/experimental/decomposition/CMakeLists.txt rename to python/cuml/cuml/experimental/decomposition/CMakeLists.txt diff --git a/python/cuml/cuml/experimental/experimental/decomposition/__init__.py b/python/cuml/cuml/experimental/decomposition/__init__.py similarity index 100% rename from python/cuml/cuml/experimental/experimental/decomposition/__init__.py rename to python/cuml/cuml/experimental/decomposition/__init__.py diff --git a/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx b/python/cuml/cuml/experimental/decomposition/kpca.pyx similarity index 99% rename from python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx rename to python/cuml/cuml/experimental/decomposition/kpca.pyx index 948df3ec10..8733f49f5a 100644 --- a/python/cuml/cuml/experimental/experimental/decomposition/kpca.pyx +++ b/python/cuml/cuml/experimental/decomposition/kpca.pyx @@ -405,6 +405,7 @@ class KernelPCA(UniversalBase, return t_input_data @enable_device_interop 
+ def transform(self, X, convert_dtype=False) -> CumlArray: """ Apply kernel-based dimensionality reduction to X.