Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 79 additions & 132 deletions mlxtend/classifier/stacking_cv_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sklearn.base import clone
from sklearn.externals import six
from sklearn.model_selection._split import check_cv
from sklearn.utils import safe_indexing


class StackingCVClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
Expand Down Expand Up @@ -182,10 +183,19 @@ def fit(self, X, y, groups=None, sample_weight=None):
# Override shuffle parameter in case of self generated
# cross-validation strategy
final_cv.shuffle = self.shuffle
skf = list(final_cv.split(X, y, groups))

all_model_predictions = np.array([]).reshape(len(y), 0)
for model in self.clfs_:
folds = list(final_cv.split(X, y, groups))

# Handle the case of X being a list of lists
# by converting X into a numpy array
if isinstance(X, list):
X = np.array(X)

meta_features = None
n_folds = final_cv.get_n_splits()
n_models = len(self.clfs_)

for n, model in enumerate(self.clfs_):

if self.verbose > 0:
i = self.clfs_.index(model) + 1
Expand All @@ -200,92 +210,41 @@ def fit(self, X, y, groups=None, sample_weight=None):
if self.verbose > 1:
print(_name_estimators((model,))[0][1])

if not self.use_probas:
single_model_prediction = np.array([]).reshape(0, 1)
else:
single_model_prediction = np.array([]).reshape(0, len(set(y)))
for num, (train_indices, test_indices) in enumerate(folds):

for num, (train_index, test_index) in enumerate(skf):
X_train = safe_indexing(X, train_indices)
y_train = safe_indexing(y, train_indices)

if self.verbose > 0:
print("Training and fitting fold %d of %d..." %
((num + 1), final_cv.get_n_splits()))

try:
if sample_weight is None:
model.fit(X[train_index], y[train_index])
else:
model.fit(X[train_index], y[train_index],
sample_weight=sample_weight[train_index])
except TypeError as e:

if str(e).startswith('A sparse matrix was passed,'
' but dense'
' data is required'):
sparse_estimator_message = (
"\nYou are likely getting this error"
" because one of the"
" estimators"
" does not support sparse matrix input.")
else:
sparse_estimator_message = ''

raise TypeError(str(e) + sparse_estimator_message +
'\nPlease check that X and y'
'are NumPy arrays. If X and y are lists'
' of lists,\ntry passing them as'
' numpy.array(X)'
' and numpy.array(y).')
except KeyError as e:

raise KeyError(str(e) + '\nPlease check that X and y'
' are NumPy arrays. If X and y are pandas'
' DataFrames,\ntry passing them as'
' X.values'
' and y.values.')
((num + 1), n_folds))

if not self.use_probas:
prediction = model.predict(X[test_index])
prediction = prediction.reshape(prediction.shape[0], 1)
if sample_weight is None:
model.fit(X_train, y_train)
else:
prediction = model.predict_proba(X[test_index])
single_model_prediction = np.vstack([single_model_prediction.
astype(prediction.dtype),
prediction])
w = safe_indexing(sample_weight, train_indices)
model.fit(X_train, y_train, sample_weight=w)

all_model_predictions = np.hstack([all_model_predictions.
astype(single_model_prediction.
dtype),
single_model_prediction])
X_test = safe_indexing(X, test_indices)
if not self.use_probas:
prediction = model.predict(X_test)[:, np.newaxis]
else:
prediction = model.predict_proba(X_test)

if meta_features is None:
# First run, use prediction to get the number of classes
n_classes = prediction.shape[1]
meta_features_shape = (X.shape[0], n_classes * n_models)
meta_features = np.empty(shape=meta_features_shape)
meta_features[np.array(test_indices)[:, np.newaxis],
np.arange(n_classes)] = prediction
else:
row_idx = np.array(test_indices)[:, np.newaxis]
col_idx = np.arange(n_classes) + n * n_classes
meta_features[row_idx, col_idx] = prediction

if self.store_train_meta_features:
# Store the meta features in the order of the
# original X,y arrays
reodered_indices = np.array([]).astype(y.dtype)
for train_index, test_index in skf:
reodered_indices = np.concatenate((reodered_indices,
test_index))
self.train_meta_features_ = all_model_predictions[np.argsort(
reodered_indices)]

# We have to shuffle the labels in the same order as we generated
# predictions during CV (we kinda shuffled them when we did
# Stratified CV).
# We also do the same with the features (we will need this only IF
# use_features_in_secondary is True)
reordered_labels = np.array([]).astype(y.dtype)
reordered_features = np.array([]).reshape((0, X.shape[1]))\
.astype(X.dtype)
for train_index, test_index in skf:
reordered_labels = np.concatenate((reordered_labels,
y[test_index]))

if sparse.issparse(X):
reordered_features = sparse.vstack((reordered_features,
X[test_index]))
else:
reordered_features = np.concatenate((reordered_features,
X[test_index]))
self.train_meta_features_ = meta_features

# Fit the base models correctly this time using ALL the training set
for model in self.clfs_:
Expand All @@ -295,18 +254,16 @@ def fit(self, X, y, groups=None, sample_weight=None):
model.fit(X, y, sample_weight=sample_weight)

# Fit the secondary model
if not self.use_features_in_secondary:
meta_features = all_model_predictions
elif sparse.issparse(X):
meta_features = sparse.hstack((reordered_features,
all_model_predictions))
else:
meta_features = np.hstack((reordered_features,
all_model_predictions))
if self.use_features_in_secondary:
meta_features = self._stack_first_level_features(
X,
meta_features
)

if sample_weight is None:
self.meta_clf_.fit(meta_features, reordered_labels)
self.meta_clf_.fit(meta_features, y)
else:
self.meta_clf_.fit(meta_features, reordered_labels,
self.meta_clf_.fit(meta_features, y,
sample_weight=sample_weight)

return self
Expand Down Expand Up @@ -347,20 +304,35 @@ def predict_meta_features(self, X):
Returns the meta-features for test data.

"""
check_is_fitted(self, 'clfs_')
all_model_predictions = np.array([]).reshape(len(X), 0)
check_is_fitted(self, ['clfs_', 'meta_clf_'])

per_model_preds = []

for model in self.clfs_:
if not self.use_probas:
single_model_prediction = model.predict(X)
single_model_prediction = single_model_prediction\
.reshape(single_model_prediction.shape[0], 1)
prediction = model.predict(X)[:, np.newaxis]
else:
single_model_prediction = model.predict_proba(X)
all_model_predictions = np.hstack((all_model_predictions.
astype(single_model_prediction
.dtype),
single_model_prediction))
return all_model_predictions
prediction = model.predict_proba(X)

per_model_preds.append(prediction)

return np.hstack(per_model_preds)

def _stack_first_level_features(self, X, meta_features):
if sparse.issparse(X):
stack_fn = sparse.hstack
else:
stack_fn = np.hstack

return stack_fn((X, meta_features))

def _do_predict(self, X, predict_fn):
meta_features = self.predict_meta_features(X)

if self.use_features_in_secondary:
meta_features = self._stack_first_level_features(X, meta_features)

return predict_fn(meta_features)

def predict(self, X):
""" Predict target values for X.
Expand All @@ -377,16 +349,9 @@ def predict(self, X):
Predicted class labels.

"""
check_is_fitted(self, 'clfs_')
all_model_predictions = self.predict_meta_features(X)
if not self.use_features_in_secondary:
return self.meta_clf_.predict(all_model_predictions)
elif sparse.issparse(X):
return self.meta_clf_.predict(
sparse.hstack((X, all_model_predictions)))
else:
return self.meta_clf_.predict(
np.hstack((X, all_model_predictions)))
check_is_fitted(self, ['clfs_', 'meta_clf_'])

return self._do_predict(X, self.meta_clf_.predict)

def predict_proba(self, X):
""" Predict class probabilities for X.
Expand All @@ -403,24 +368,6 @@ def predict_proba(self, X):
Probability for each class per sample.

"""
check_is_fitted(self, 'clfs_')
all_model_predictions = np.array([]).reshape(len(X), 0)
for model in self.clfs_:
if not self.use_probas:
single_model_prediction = model.predict(X)
single_model_prediction = single_model_prediction\
.reshape(single_model_prediction.shape[0], 1)
else:
single_model_prediction = model.predict_proba(X)
all_model_predictions = np.hstack((all_model_predictions.
astype(single_model_prediction.
dtype),
single_model_prediction))
if not self.use_features_in_secondary:
return self.meta_clf_.predict_proba(all_model_predictions)
elif sparse.issparse(X):
self.meta_clf_\
.predict_proba(sparse.hstack((X, all_model_predictions)))
else:
return self.meta_clf_\
.predict_proba(np.hstack((X, all_model_predictions)))
check_is_fitted(self, ['clfs_', 'meta_clf_'])

return self._do_predict(X, self.meta_clf_.predict_proba)
75 changes: 40 additions & 35 deletions mlxtend/classifier/tests/test_stacking_cv_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,39 +329,6 @@ def test_verbose():
sclf.fit(X_iris, y_iris)


def test_list_of_lists():
X_list = [i for i in X_iris]
meta = LogisticRegression(multi_class='ovr', solver='liblinear')
clf1 = RandomForestClassifier(n_estimators=10)
clf2 = GaussianNB()
sclf = StackingCVClassifier(classifiers=[clf1, clf2],
use_probas=True,
meta_classifier=meta,
shuffle=False,
verbose=0)

try:
sclf.fit(X_list, y_iris)
except TypeError as e:
assert 'are NumPy arrays. If X and y are lists' in str(e)


def test_pandas():
X_df = pd.DataFrame(X_iris)
meta = LogisticRegression(multi_class='ovr', solver='liblinear')
clf1 = RandomForestClassifier(n_estimators=10)
clf2 = GaussianNB()
sclf = StackingCVClassifier(classifiers=[clf1, clf2],
use_probas=True,
meta_classifier=meta,
shuffle=False,
verbose=0)
try:
sclf.fit(X_df, y_iris)
except KeyError as e:
assert 'are NumPy arrays. If X and y are pandas DataFrames' in str(e)


def test_get_params():
clf1 = KNeighborsClassifier(n_neighbors=1)
clf2 = RandomForestClassifier(random_state=1)
Expand Down Expand Up @@ -493,8 +460,8 @@ def test_sparse_inputs_with_features_in_secondary():
stclf = StackingCVClassifier(classifiers=[rf, rf],
meta_classifier=lr,
use_features_in_secondary=True)
X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast,
test_size=0.3)
X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast,
test_size=0.3)

# dense
stclf.fit(X_train, y_train)
Expand All @@ -505,3 +472,41 @@ def test_sparse_inputs_with_features_in_secondary():
stclf.fit(sparse.csr_matrix(X_train), y_train)
assert round(stclf.score(X_train, y_train), 2) == 0.99, \
round(stclf.score(X_train, y_train), 2)


def test_works_with_df_if_fold_indexes_missing():
"""This is a regression test to make sure fitting will still work even if
training data has ids that cannot be indexed using the indexes from the cv
(e.g. skf)

Some possibilities:
+ Output of the folds are not neatly consecutive (i.e. [341, 345, 543, ...]
instead of [0, 1, ... n])
+ Indexes just start from some number greater than the size of the input
(see test case)

Training data sometimes has ids that carry other information, and selection
of rows based on cv should not break.

This is fixed in the code using `safe_indexing`
"""

np.random.seed(123)
rf = RandomForestClassifier(n_estimators=10)
lr = LogisticRegression(multi_class='ovr', solver='liblinear')
stclf = StackingCVClassifier(classifiers=[rf, rf],
meta_classifier=lr,
use_features_in_secondary=True)

X_modded = pd.DataFrame(X_breast,
index=np.arange(X_breast.shape[0]) + 1000)
y_modded = pd.Series(y_breast,
index=np.arange(y_breast.shape[0]) + 1000)

X_train, X_test, y_train, y_test = train_test_split(X_modded, y_modded,
test_size=0.3)

# dense
stclf.fit(X_train, y_train)
assert round(stclf.score(X_train, y_train), 2) == 0.99, \
round(stclf.score(X_train, y_train), 2)