@aamijar (Member) commented Nov 11, 2025

Resolves #7154

This PR adds support for TargetEncoder in cuml.accel. The feature was originally requested by the Kaggle team.

The cuML and sklearn implementations of TargetEncoder have different APIs, and these differences must be handled when translating between CPU and GPU models.

| cuML TargetEncoder param | sklearn TargetEncoder param | Transformation / Notes |
| --- | --- | --- |
| n_folds | cv | Direct mapping |
| seed | random_state | If random_state is None, defaults to 42 |
| smooth | smooth | If smooth == "auto", set to 1.0; else float(model.smooth) |
| split_method | shuffle | "random" if shuffle=True, otherwise "continuous" |
| output_type | (no sklearn equivalent) | Always "auto" |
| stat | (no sklearn equivalent) | Always "mean" |
| (no cuml equivalent) | categories | Always "auto" |
| (no cuml equivalent) | target_type | Always "continuous" |
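
As a rough illustration of the mapping above, the translation could be sketched as two plain functions over parameter dictionaries. This is not the code in this PR: the function names `sklearn_to_cuml_params` and `cuml_to_sklearn_params` are hypothetical, the fallback values (`cv=5`, `shuffle=True`, `smooth="auto"`) are assumed to match sklearn's documented defaults, and the reverse rule for `shuffle` is an assumption inferred from the table, which only states the sklearn-to-cuML direction.

```python
def sklearn_to_cuml_params(sk_params):
    """Translate sklearn-style TargetEncoder params to cuML-style kwargs.

    Hypothetical helper: it follows the mapping table row by row and does
    not call into cuML or sklearn.
    """
    smooth = sk_params.get("smooth", "auto")
    random_state = sk_params.get("random_state")
    shuffle = sk_params.get("shuffle", True)
    return {
        "n_folds": sk_params.get("cv", 5),                      # direct mapping
        "seed": 42 if random_state is None else random_state,   # None -> 42
        "smooth": 1.0 if smooth == "auto" else float(smooth),   # "auto" -> 1.0
        "split_method": "random" if shuffle else "continuous",
        "output_type": "auto",                                  # no sklearn equivalent
        "stat": "mean",                                         # no sklearn equivalent
    }


def cuml_to_sklearn_params(cu_params):
    """Translate cuML-style kwargs back to sklearn-style params.

    Assumption: any split_method other than "random" maps to shuffle=False.
    """
    return {
        "cv": cu_params["n_folds"],
        "random_state": cu_params["seed"],
        "smooth": cu_params["smooth"],
        "shuffle": cu_params["split_method"] == "random",
        "categories": "auto",          # no cuML equivalent
        "target_type": "continuous",   # no cuML equivalent
    }


if __name__ == "__main__":
    gpu_kwargs = sklearn_to_cuml_params({"cv": 3, "smooth": "auto", "shuffle": False})
    print(gpu_kwargs)
    print(cuml_to_sklearn_params(gpu_kwargs))
```

Note that the mapping is lossy in both directions: `smooth="auto"` and `random_state=None` cannot be recovered after a round trip, because they are replaced by concrete values on the way to the GPU model.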

Testing upstream

./run-tests.sh -k "targetencoder"

Current failures

FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_n_features_in_after_fitting] - AssertionError: `TargetEncoder.fit()` does not set the `n_features_in_` attribute. You might want to use `sklearn.utils.validation.validate_data` instead of `check_array` in `TargetEncoder.fit()` which takes care of setting the attribute.
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_complex_data] - NotImplementedError: complex128 not supported
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_dtype_object] - TypeError: Cannot convert a floating of object type
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_estimators_empty_data_messages] - AssertionError: The estimator TargetEncoder does not raise a ValueError when an empty data is used to train. Perhaps use check_array in train.
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_estimators_pickle] - AttributeError: 'TargetEncoder' object has no attribute 'categories_'. Did you mean: 'categories'?
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_estimators_pickle(readonly_memmap=True)] - AttributeError: 'TargetEncoder' object has no attribute 'categories_'. Did you mean: 'categories'?
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_transformer_data_not_an_array] - TypeError: Input of type <class 'sklearn.utils.estimator_checks._NotAnArray'> is not cudf.Series, cudf.DataFrame or pandas.Series or pandas.DataFrameor cupy.ndarray or numpy.ndarray
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_transformer_general] - AssertionError: The transformer TargetEncoder does not raise an error when the number of features in transform is different from the number of features in fit.
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_transformer_general(readonly_memmap=True)] - AssertionError: The transformer TargetEncoder does not raise an error when the number of features in transform is different from the number of features in fit.
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_methods_sample_order_invariance] - AssertionError: 
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_methods_subset_invariance] - AssertionError: 
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_fit_idempotent] - AttributeError: 'TargetEncoder' object has no attribute 'set_params'. Did you mean: 'get_params'?
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_fit_check_is_fitted] - sklearn.exceptions.NotFittedError: Estimator fails to pass `check_is_fitted` even though it has been fit.
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_n_features_in] - AssertionError
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_fit1d] - AssertionError: Did not raise: [<class 'ValueError'>]
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_fit2d_predict1d] - KeyError: '__FEA__'
FAILED tests/test_common.py::test_estimators[TargetEncoder()-check_requires_y_none] - TypeError: Input of type <class 'NoneType'> is not cudf.Series, or pandas.Seriesor numpy.ndarrayor cupy.ndarray
FAILED tests/test_common.py::test_pandas_column_name_consistency[TargetEncoder()] - ValueError: Estimator does not have a feature_names_in_ attribute after fitting with a dataframe
FAILED tests/test_common.py::test_check_inplace_ensure_writeable[TargetEncoder()] - unittest.case.SkipTest: TargetEncoder doesn't require writeable input.
FAILED tests/test_docstring_parameters.py::test_fit_docstring_attributes[TargetEncoder-TargetEncoder] - AssertionError: assert False
============================================================================================== 20 failed, 41 passed, 41404 deselected, 28 warnings in 14.08s ==============================================================================================

Testing local

pytest test_sklearn_import_export.py -k "target_encoder"

Current failures

AttributeError: 'TargetEncoder' object has no attribute 'categories_'. Did you mean: 'categories'?
============================== short test summary info ==============================
FAILED test_sklearn_import_export.py::test_target_encoder - AttributeError: 'TargetEncoder' object has no attribute 'categories_'. Did you m...

copy-pr-bot bot commented Nov 11, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@github-actions bot added the "Cython / Python" label Nov 11, 2025
@aamijar added the "cuml-accel", "non-breaking", and "feature request" labels Nov 11, 2025


-class TargetEncoder:
+class TargetEncoder(InteropMixin):

Review comment (Member):

For everything to work properly, you'll also need to make TargetEncoder a subclass of Base. Looking at the current implementation, I think this would entail:

  • Adding Base to the base class list (it should go first, before the mixin).
  • Updating the definition of _get_param_names to also include super()._get_param_names() (please also move this definition to the top of the class, as we've done on other estimators).
  • Ripping out the custom infra in the class, like _get_output_type/get_params/.... Basically everything that isn't there to implement fit/fit_transform/transform should move to the Base infra.
  • Adding CumlArray return type annotations to transform/fit_transform to enable method type reflection.
  • Possibly using a CumlArrayDescriptor to reflect fitted attributes, though from looking at the list in the sklearn docs I don't think that's necessary.
  • Ensuring we have adequate test coverage for this estimator so we're not unexpectedly breaking things. Since this wasn't a Base subclass and wasn't doing type reflection the way we do elsewhere, I wouldn't be surprised if we see differences in behavior after this. But if we're moving towards our expected standard, I'd view those "breaking changes" more as bugfixes, since this estimator doesn't currently follow our conventions.

Overall, it looks like there's a bunch of cleanup work to do in this estimator; making it ready for cuml-accel is not necessarily a light lift.
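
To make the first three points concrete, here is a minimal sketch of the class layout being suggested. It does not import cuML: `Base` and `InteropMixin` below are simplified stand-ins written only for this illustration, and the constructor defaults are placeholders, not necessarily the real ones. The points it tries to show are the base class order (Base first, then the mixin), `_get_param_names` extending the parent's list, and `get_params`/`set_params` coming from the base class instead of being reimplemented in the estimator.

```python
class Base:
    """Simplified stand-in for cuML's Base (hypothetical; parameter handling only)."""

    def __init__(self, *, verbose=False, output_type=None):
        self.verbose = verbose
        self.output_type = output_type

    @classmethod
    def _get_param_names(cls):
        return ["verbose", "output_type"]

    def get_params(self, deep=True):
        return {name: getattr(self, name) for name in self._get_param_names()}

    def set_params(self, **params):
        valid = self._get_param_names()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self


class InteropMixin:
    """Simplified stand-in for the interop mixin; no behavior is needed here."""


class TargetEncoder(Base, InteropMixin):  # Base goes first, before the mixin
    @classmethod
    def _get_param_names(cls):
        # Extend the parent's list instead of replacing it.
        return [
            *super()._get_param_names(),
            "n_folds",
            "smooth",
            "seed",
            "split_method",
            "stat",
        ]

    def __init__(self, *, n_folds=4, smooth=0.0, seed=42,
                 split_method="interleaved", stat="mean",
                 verbose=False, output_type=None):
        # Placeholder defaults, chosen for illustration only.
        super().__init__(verbose=verbose, output_type=output_type)
        self.n_folds = n_folds
        self.smooth = smooth
        self.seed = seed
        self.split_method = split_method
        self.stat = stat


if __name__ == "__main__":
    enc = TargetEncoder(n_folds=3)
    print(enc.get_params())
    enc.set_params(smooth=2.0)
    print(enc.smooth)
```

With this layout the estimator no longer needs its own `get_params` or `set_params`, which should also address the `check_fit_idempotent` failure above that reports a missing `set_params`.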

"target_mean_": to_cpu(self.mean),
"n_features_in_": len(self.train.columns) - 3,
**super()._attrs_to_cpu(model),
}

Review comment (Member):

_attrs_to_cpu and _attrs_from_cpu should really round-trip all fitted state. Right now it looks like you're reading categories_ in _attrs_from_cpu but not setting it in _attrs_to_cpu, which is a bug.
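
The symmetry being asked for can be sketched without cuML at all. In the snippet below, `to_cpu_attrs`, `from_cpu_attrs`, and `roundtrip` are hypothetical stand-ins for the two interop hooks (the real methods have different signatures), and the attribute list is limited to the three names that appear in this PR. The idea is that every attribute the import direction reads must also be written by the export direction.

```python
from types import SimpleNamespace


def to_cpu_attrs(state):
    """Export direction: write every fitted attribute the import side reads."""
    return {
        "categories_": state["categories_"],
        "target_mean_": state["target_mean_"],
        "n_features_in_": state["n_features_in_"],
    }


def from_cpu_attrs(cpu_model):
    """Import direction: read fitted attributes from a CPU-side model object."""
    return {
        "categories_": cpu_model.categories_,
        "target_mean_": cpu_model.target_mean_,
        "n_features_in_": cpu_model.n_features_in_,
    }


def roundtrip(state):
    """Export to a fake CPU model, then import again."""
    cpu_model = SimpleNamespace(**to_cpu_attrs(state))
    return from_cpu_attrs(cpu_model)


if __name__ == "__main__":
    fitted = {
        "categories_": [["a", "b"]],
        "target_mean_": 0.5,
        "n_features_in_": 1,
    }
    # A symmetric pair of hooks returns exactly the state it started with.
    assert roundtrip(fitted) == fitted
```

If `to_cpu_attrs` omitted `categories_`, `from_cpu_attrs` would raise an AttributeError on the exported model, which is consistent with the `categories_` failure reported in the local test run above.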



Development

Successfully merging this pull request may close these issues.

Support for TargetEncoder in cuml.accel
