Skip to content

Commit d07722d

Browse files
committed
Remove base metaclass and old decorator names
1 parent d994ed4 commit d07722d

File tree

8 files changed

+42
-153
lines changed

8 files changed

+42
-153
lines changed

python/cuml/cuml/dask/common/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
from cuml.dask.common import parts_to_ranks
2121
from cuml.dask.common.input_utils import DistributedDataHandler
2222
from cuml.dask.common.utils import get_client, wait_and_raise_from_futures
23-
from cuml.internals import BaseMetaClass
2423
from cuml.internals.array import CumlArray
2524
from cuml.internals.base import Base
2625

2726

28-
class BaseEstimator(object, metaclass=BaseMetaClass):
27+
class BaseEstimator:
2928
def __init__(self, *, client=None, verbose=False, **kwargs):
3029
"""
3130
Constructor for distributed estimators.

python/cuml/cuml/internals/__init__.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,5 @@
22
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
33
# SPDX-License-Identifier: Apache-2.0
44
#
5-
from cuml.internals.base_helpers import BaseMetaClass, _tags_class_and_instance
6-
from cuml.internals.constants import CUML_WRAPPED_FLAG
75
from cuml.internals.internals import GraphBasedDimRedCallback
8-
from cuml.internals.outputs import (
9-
api_base_fit_transform,
10-
api_base_return_any,
11-
api_base_return_any_skipall,
12-
api_base_return_array,
13-
api_base_return_array_skipall,
14-
api_return_any,
15-
api_return_array,
16-
exit_internal_api,
17-
reflect,
18-
)
6+
from cuml.internals.outputs import exit_internal_api, reflect

python/cuml/cuml/internals/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from cuml.internals.outputs import check_output_type
1919

2020

21-
class Base(TagsMixin, metaclass=cuml.internals.BaseMetaClass):
21+
class Base(TagsMixin):
2222
"""
2323
Base class for all the ML algos. It handles some of the common operations
2424
across all algos. Every ML algo class exposed at cython level must inherit

python/cuml/cuml/internals/base_helpers.py

Lines changed: 0 additions & 94 deletions
This file was deleted.

python/cuml/cuml/internals/constants.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

python/cuml/cuml/internals/mixins.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from cuml._thirdparty._sklearn_compat import _to_new_tags
1010
from cuml.common.doc_utils import generate_docstring
11-
from cuml.internals.base_helpers import _tags_class_and_instance
1211
from cuml.internals.outputs import reflect
1312

1413
###############################################################################
@@ -45,6 +44,34 @@
4544
}
4645

4746

47+
class _tags_class_and_instance:
48+
"""
49+
Decorator for mixins to allow for dynamic and static _get_tags.
50+
In general, most methods are either dynamic or static, so this decorator
51+
is only meant to be used in the mixins _get_tags.
52+
"""
53+
54+
def __init__(self, _class, _instance=None):
55+
self._class = _class
56+
self._instance = _instance
57+
58+
def instance_method(self, _instance):
59+
"""
60+
Factory to create a _tags_class_and_instance instance method with
61+
the existing class associated.
62+
"""
63+
return _tags_class_and_instance(self._class, _instance)
64+
65+
def __get__(self, _instance, _class):
66+
# if the caller had no instance (i.e. it was a class) or there is no
67+
# instance associated we the method we return the class call
68+
if _instance is None or self._instance is None:
69+
return self._class.__get__(_class, None)
70+
71+
# otherwise return instance call
72+
return self._instance.__get__(_instance, _class)
73+
74+
4875
class TagsMixin:
4976
@_tags_class_and_instance
5077
def _get_tags(cls):

python/cuml/cuml/internals/outputs.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# TODO: Try to resolve circular import that makes this necessary:
1212
from cuml.internals import input_utils as iu
1313
from cuml.internals.array_sparse import SparseCumlArray
14-
from cuml.internals.constants import CUML_WRAPPED_FLAG
1514
from cuml.internals.global_settings import GlobalSettings
1615

1716
__all__ = (
@@ -353,9 +352,6 @@ def reflect(
353352
skip=skip,
354353
)
355354

356-
# TODO: remove this once auto-decorating is ripped out
357-
setattr(func, CUML_WRAPPED_FLAG, True)
358-
359355
sig = inspect.signature(func, follow_wrapped=True)
360356
has_self = "self" in sig.parameters
361357

@@ -437,28 +433,3 @@ def inner(*args, **kwargs):
437433
return res
438434

439435
return inner
440-
441-
442-
def api_return_array(input_arg=default, get_output_type=False):
443-
return reflect(array=None if not get_output_type else input_arg)
444-
445-
446-
def api_return_any():
447-
return reflect(array=None, skip=True)
448-
449-
450-
def api_base_return_any():
451-
return reflect(reset=True)
452-
453-
454-
def api_base_return_array(input_arg=default):
455-
return reflect(array="self" if input_arg is None else input_arg)
456-
457-
458-
def api_base_fit_transform():
459-
return reflect(reset=True)
460-
461-
462-
# TODO: investigate and remove these
463-
api_base_return_any_skipall = api_return_any()
464-
api_base_return_array_skipall = reflect

python/cuml/tests/test_cuml_descr_decor.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,30 @@ class DummyTestEstimator(cuml.Base):
4545
def _set_input(self, X):
4646
self.input_any_ = X
4747

48-
@cuml.internals.api_base_return_any()
48+
@cuml.internals.reflect(reset=True)
4949
def store_input(self, X):
5050
self.input_any_ = X
5151

52-
@cuml.internals.api_return_any()
52+
@cuml.internals.reflect(skip=True)
5353
def get_input(self):
5454
return self.input_any_
5555

56-
# === Standard Functions ===
57-
def fit(self, X, convert_dtype=True) -> "DummyTestEstimator":
56+
@cuml.internals.reflect(reset=True)
57+
def fit(self, X, convert_dtype=True):
5858
self._set_output_type(X)
5959
self._set_n_features_in(X)
6060
return self
6161

62-
def predict(self, X, convert_dtype=True) -> CumlArray:
62+
@cuml.internals.reflect
63+
def predict(self, X, convert_dtype=True):
6364
return X
6465

65-
def transform(self, X, convert_dtype=False) -> CumlArray:
66+
@cuml.internals.reflect
67+
def transform(self, X, convert_dtype=False):
6668
pass
6769

68-
def fit_transform(self, X, y=None) -> CumlArray:
70+
@cuml.internals.reflect
71+
def fit_transform(self, X, y=None):
6972
return self.fit(X).transform(X)
7073

7174

@@ -280,7 +283,7 @@ def test_return_array(input_arg: str):
280283
X_in = create_input(input_type_X, input_dtype_X, (10, 10), "F")
281284
Y_in = create_input(input_type_Y, input_dtype_Y, (10, 10), "F")
282285

283-
@cuml.internals.api_return_array(input_arg=input_arg, get_output_type=True)
286+
@cuml.internals.reflect(array=input_arg)
284287
def test_func(X, y):
285288
return X
286289

0 commit comments

Comments
 (0)