Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/cuml/cuml/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from cuml.internals.api_context_managers import (
in_internal_api,
set_api_output_dtype,
set_api_output_type,
)
from cuml.internals.api_decorators import (
Expand Down
33 changes: 1 addition & 32 deletions python/cuml/cuml/internals/api_context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,6 @@ def set_api_output_type(output_type: str):
GlobalSettings().root_cm.output_type = array_type


def set_api_output_dtype(output_dtype):
assert GlobalSettings().root_cm is not None

# Try to convert any array objects to their type
if output_dtype is not None and cuml.internals.input_utils.is_array_like(
output_dtype
):
output_dtype = cuml.internals.input_utils.determine_array_dtype(
output_dtype
)

assert output_dtype is not None

GlobalSettings().root_cm.output_dtype = output_dtype


class InternalAPIContext(contextlib.ExitStack):
def __init__(self):
super().__init__()
Expand All @@ -89,8 +73,6 @@ def cleanup():
self.enter_context(cupy_using_allocator(rmm_cupy_allocator))
self.prev_output_type = self.enter_context(_using_mirror_output_type())

self.output_dtype = None

# Set the output type to the prev_output_type. If "input", set to None
# to allow inner functions to specify the input
self.output_type = (
Expand Down Expand Up @@ -124,24 +106,14 @@ def __exit__(self, *exc_details):
def push_output_types(self):
try:
old_output_type = self.output_type
old_output_dtype = self.output_dtype

self.output_type = None
self.output_dtype = None

yield

finally:
self.output_type = (
old_output_type
if old_output_type is not None
else self.output_type
)
self.output_dtype = (
old_output_dtype
if old_output_dtype is not None
else self.output_dtype
)


def get_internal_context() -> InternalAPIContext:
Expand Down Expand Up @@ -348,10 +320,7 @@ def convert_to_outputtype(self, ret_val):
and output_type != "input"
), ("Invalid root_cm.output_type: '{}'.").format(output_type)

return ret_val.to_output(
output_type=output_type,
output_dtype=self._context.root_cm.output_dtype,
)
return ret_val.to_output(output_type=output_type)


class ProcessReturnSparseArray(ProcessReturnArray):
Expand Down
35 changes: 1 addition & 34 deletions python/cuml/cuml/internals/api_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ReturnArrayCM,
ReturnGenericCM,
ReturnSparseArrayCM,
set_api_output_dtype,
set_api_output_type,
)
from cuml.internals.constants import CUML_WRAPPED_FLAG
Expand Down Expand Up @@ -103,11 +102,8 @@ def _make_decorator_function(

def decorator_function(
input_arg: str = ...,
target_arg: str = ...,
get_output_type: bool = False,
set_output_type: bool = False,
get_output_dtype: bool = False,
set_output_dtype: bool = False,
set_n_features_in: bool = False,
) -> _DecoratorType:
def decorator_closure(func):
Expand All @@ -124,20 +120,12 @@ def decorator_closure(func):
raise Exception("No self found on function!")

if input_arg is not None and (
set_output_type
or set_output_dtype
or set_n_features_in
or get_output_type
set_output_type or set_n_features_in or get_output_type
):
input_arg_ = _find_arg(sig, input_arg or "X", 0)
else:
input_arg_ = None

if set_output_dtype or (get_output_dtype and not has_self):
target_arg_ = _find_arg(sig, target_arg or "y", 1)
else:
target_arg_ = None

@_wrap_once(func)
def wrapper(*args, **kwargs):
# Wraps the decorated function, executed at runtime.
Expand All @@ -157,22 +145,10 @@ def wrapper(*args, **kwargs):
)
else:
input_val = None
if target_arg_:
target_val = _get_value(
args,
kwargs,
*target_arg_,
accept_lists=accept_lists,
)
else:
target_val = None

if set_output_type:
assert self_val is not None
self_val._set_output_type(input_val)
if set_output_dtype:
assert self_val is not None
self_val._set_target_dtype(target_val)
if set_n_features_in and len(input_val.shape) >= 2:
assert self_val is not None
self_val._set_n_features_in(input_val)
Expand All @@ -186,15 +162,6 @@ def wrapper(*args, **kwargs):

set_api_output_type(out_type)

if get_output_dtype:
if self_val is None:
assert target_val is not None
output_dtype = iu.determine_array_dtype(target_val)
else:
output_dtype = self_val._get_target_dtype()

set_api_output_dtype(output_dtype)

if process_return:
ret = func(*args, **kwargs)
else:
Expand Down
59 changes: 0 additions & 59 deletions python/cuml/cuml/internals/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def __init__(
else output_type
)
self._input_type = None
self.target_dtype = None

nvtx_benchmark = os.getenv("NVTX_BENCHMARK")
if nvtx_benchmark and nvtx_benchmark.lower() == "true":
Expand Down Expand Up @@ -303,47 +302,6 @@ def set_params(self, **params):
setattr(self, key, value)
return self

def _set_base_attributes(
self, output_type=None, target_dtype=None, n_features=None
):
"""
Method to set the base class attributes - output type,
target dtype and n_features. It combines the three different
function calls. It's called in fit function from estimators.

Parameters
--------
output_type : DataFrame (default = None)
Is output_type is passed, aets the output_type on the
dataframe passed
target_dtype : Target column (default = None)
If target_dtype is passed, we call _set_target_dtype
on it
n_features: int or DataFrame (default=None)
If an int is passed, we set it to the number passed
If dataframe, we set it based on the passed df.

Examples
--------

.. code-block:: python

# To set output_type and n_features based on X
self._set_base_attributes(output_type=X, n_features=X)

# To set output_type on X and n_features to 10
self._set_base_attributes(output_type=X, n_features=10)

# To only set target_dtype
self._set_base_attributes(output_type=X, target_dtype=y)
"""
if output_type is not None:
self._set_output_type(output_type)
if target_dtype is not None:
self._set_target_dtype(target_dtype)
if n_features is not None:
self._set_n_features_in(n_features)

def _set_output_type(self, inp):
self._input_type = determine_array_type(inp)

Expand Down Expand Up @@ -372,23 +330,6 @@ class output type and global output type.

return output_type

def _set_target_dtype(self, target):
self.target_dtype = cuml.internals.input_utils.determine_array_dtype(
target
)

def _get_target_dtype(self):
"""
Method to be called by predict/transform methods of
inheriting classifier classes. Returns the appropriate output
dtype depending on the dtype of the target.
"""
try:
out_dtype = self.target_dtype
except AttributeError:
out_dtype = None
return out_dtype

def _set_n_features_in(self, X):
if isinstance(X, int):
self.n_features_in_ = X
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/cuml/neighbors/nearest_neighbors_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ class NearestNeighborsMG(NearestNeighbors):

def get_out_type(self, index, query):
if len(index) > 0:
self._set_base_attributes(output_type=index[0])
self._set_output_type(index[0])
if len(query) > 0:
self._set_base_attributes(output_type=query[0])
self._set_output_type(query[0])

@staticmethod
def gen_local_input(index, index_parts_to_ranks, index_nrows,
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cuml/tsa/arima.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class ARIMA(Base):
super().__init__(handle=handle,
verbose=verbose,
output_type=output_type)
self._set_base_attributes(output_type=endog)
self._set_output_type(endog)

# Check validity of the ARIMA order and seasonal order
p, d, q = order
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cuml/tsa/auto_arima.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class AutoARIMA(Base):
super().__init__(handle=handle,
verbose=verbose,
output_type=output_type)
self._set_base_attributes(output_type=endog)
self._set_output_type(endog)

# Get device array. Float64 only for now.
self.d_y, self.n_obs, self.batch_size, self.dtype \
Expand Down
36 changes: 4 additions & 32 deletions python/cuml/tests/test_cuml_descr_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.array import CumlArray
from cuml.internals.input_utils import (
determine_array_dtype,
determine_array_type,
input_to_cuml_array,
)
Expand Down Expand Up @@ -56,7 +55,8 @@ def get_input(self):

# === Standard Functions ===
def fit(self, X, convert_dtype=True) -> "DummyTestEstimator":
self._set_base_attributes(output_type=X, n_features=X)
self._set_output_type(X)
self._set_n_features_in(X)
return self

def predict(self, X, convert_dtype=True) -> CumlArray:
Expand Down Expand Up @@ -213,7 +213,6 @@ def calc_n_features(shape):
return 1

assert est._input_type == input_type
assert est.target_dtype is None
assert est.n_features_in_ == calc_n_features(input_shape)


Expand Down Expand Up @@ -268,15 +267,8 @@ def test_auto_predict(input_type, base_output_type, global_output_type):


@pytest.mark.parametrize("input_arg", ["X", "y", "bad", ...])
@pytest.mark.parametrize("target_arg", ["X", "y", "bad", ...])
@pytest.mark.parametrize("get_output_type", [True, False])
@pytest.mark.parametrize("get_output_dtype", [True, False])
def test_return_array(
input_arg: str,
target_arg: str,
get_output_type: bool,
get_output_dtype: bool,
):
def test_return_array(input_arg: str, get_output_type: bool):
"""
Test autowrapping on predict that will set target_type
"""
Expand All @@ -288,30 +280,21 @@ def test_return_array(
input_dtype_Y = np.int32

inner_type = "numba"
inner_dtype = np.float16

X_in = create_input(input_type_X, input_dtype_X, (10, 10), "F")
Y_in = create_input(input_type_Y, input_dtype_Y, (10, 10), "F")

def test_func(X, y):
if not get_output_type:
cuml.internals.set_api_output_type(inner_type)

if not get_output_dtype:
cuml.internals.set_api_output_dtype(inner_dtype)

return X

expected_to_fail = (input_arg == "bad" and get_output_type) or (
target_arg == "bad" and get_output_dtype
)
expected_to_fail = input_arg == "bad" and get_output_type

try:
test_func = cuml.internals.api_return_array(
input_arg=input_arg,
target_arg=target_arg,
get_output_type=get_output_type,
get_output_dtype=get_output_dtype,
)(test_func)
except ValueError:
assert expected_to_fail
Expand All @@ -322,7 +305,6 @@ def test_func(X, y):
X_out = test_func(X=X_in, y=Y_in)

target_type = None
target_dtype = None

if not get_output_type:
target_type = inner_type
Expand All @@ -332,14 +314,4 @@ def test_func(X, y):
else:
target_type = input_type_X

if not get_output_dtype:
target_dtype = inner_dtype
else:
if target_arg == "X":
target_dtype = input_dtype_X
else:
target_dtype = input_dtype_Y

assert determine_array_type(X_out) == target_type

assert determine_array_dtype(X_out) == target_dtype