diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index 5633f5dc4..2eb329492 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -19,6 +19,7 @@ from sklearn.base import clone from sklearn.externals import six from sklearn.model_selection._split import check_cv +from sklearn.utils import safe_indexing class StackingCVClassifier(BaseEstimator, ClassifierMixin, TransformerMixin): @@ -182,10 +183,19 @@ def fit(self, X, y, groups=None, sample_weight=None): # Override shuffle parameter in case of self generated # cross-validation strategy final_cv.shuffle = self.shuffle - skf = list(final_cv.split(X, y, groups)) - all_model_predictions = np.array([]).reshape(len(y), 0) - for model in self.clfs_: + folds = list(final_cv.split(X, y, groups)) + + # Handle the case of X being a list of lists + # by converting X into a numpy array + if isinstance(X, list): + X = np.array(X) + + meta_features = None + n_folds = final_cv.get_n_splits() + n_models = len(self.clfs_) + + for n, model in enumerate(self.clfs_): if self.verbose > 0: i = self.clfs_.index(model) + 1 @@ -200,92 +210,41 @@ def fit(self, X, y, groups=None, sample_weight=None): if self.verbose > 1: print(_name_estimators((model,))[0][1]) - if not self.use_probas: - single_model_prediction = np.array([]).reshape(0, 1) - else: - single_model_prediction = np.array([]).reshape(0, len(set(y))) + for num, (train_indices, test_indices) in enumerate(folds): - for num, (train_index, test_index) in enumerate(skf): + X_train = safe_indexing(X, train_indices) + y_train = safe_indexing(y, train_indices) if self.verbose > 0: print("Training and fitting fold %d of %d..." % - ((num + 1), final_cv.get_n_splits())) - - try: - if sample_weight is None: - model.fit(X[train_index], y[train_index]) - else: - model.fit(X[train_index], y[train_index], - sample_weight=sample_weight[train_index]) - except TypeError as e: - - if str(e).startswith('A sparse matrix was passed,' - ' but dense' - ' data is required'): - sparse_estimator_message = ( - "\nYou are likely getting this error" - " because one of the" - " estimators" - " does not support sparse matrix input.") - else: - sparse_estimator_message = '' - - raise TypeError(str(e) + sparse_estimator_message + - '\nPlease check that X and y' - 'are NumPy arrays. If X and y are lists' - ' of lists,\ntry passing them as' - ' numpy.array(X)' - ' and numpy.array(y).') - except KeyError as e: - - raise KeyError(str(e) + '\nPlease check that X and y' - ' are NumPy arrays. If X and y are pandas' - ' DataFrames,\ntry passing them as' - ' X.values' - ' and y.values.') + ((num + 1), n_folds)) - if not self.use_probas: - prediction = model.predict(X[test_index]) - prediction = prediction.reshape(prediction.shape[0], 1) + if sample_weight is None: + model.fit(X_train, y_train) else: - prediction = model.predict_proba(X[test_index]) - single_model_prediction = np.vstack([single_model_prediction. - astype(prediction.dtype), - prediction]) + w = safe_indexing(sample_weight, train_indices) + model.fit(X_train, y_train, sample_weight=w) - all_model_predictions = np.hstack([all_model_predictions. - astype(single_model_prediction. - dtype), - single_model_prediction]) + X_test = safe_indexing(X, test_indices) + if not self.use_probas: + prediction = model.predict(X_test)[:, np.newaxis] + else: + prediction = model.predict_proba(X_test) + + if meta_features is None: + # First run, use prediction to get the number of classes + n_classes = prediction.shape[1] + meta_features_shape = (X.shape[0], n_classes * n_models) + meta_features = np.empty(shape=meta_features_shape) + meta_features[np.array(test_indices)[:, np.newaxis], + np.arange(n_classes)] = prediction + else: + row_idx = np.array(test_indices)[:, np.newaxis] + col_idx = np.arange(n_classes) + n * n_classes + meta_features[row_idx, col_idx] = prediction if self.store_train_meta_features: - # Store the meta features in the order of the - # original X,y arrays - reodered_indices = np.array([]).astype(y.dtype) - for train_index, test_index in skf: - reodered_indices = np.concatenate((reodered_indices, - test_index)) - self.train_meta_features_ = all_model_predictions[np.argsort( - reodered_indices)] - - # We have to shuffle the labels in the same order as we generated - # predictions during CV (we kinda shuffled them when we did - # Stratified CV). - # We also do the same with the features (we will need this only IF - # use_features_in_secondary is True) - reordered_labels = np.array([]).astype(y.dtype) - reordered_features = np.array([]).reshape((0, X.shape[1]))\ - .astype(X.dtype) - for train_index, test_index in skf: - reordered_labels = np.concatenate((reordered_labels, - y[test_index])) - - if sparse.issparse(X): - reordered_features = sparse.vstack((reordered_features, - X[test_index])) - else: - reordered_features = np.concatenate((reordered_features, - X[test_index])) + self.train_meta_features_ = meta_features # Fit the base models correctly this time using ALL the training set for model in self.clfs_: @@ -295,18 +254,16 @@ def fit(self, X, y, groups=None, sample_weight=None): model.fit(X, y, sample_weight=sample_weight) # Fit the secondary model - if not self.use_features_in_secondary: - meta_features = all_model_predictions - elif sparse.issparse(X): - meta_features = sparse.hstack((reordered_features, - all_model_predictions)) - else: - meta_features = np.hstack((reordered_features, - all_model_predictions)) + if self.use_features_in_secondary: + meta_features = self._stack_first_level_features( + X, + meta_features + ) + if sample_weight is None: - self.meta_clf_.fit(meta_features, reordered_labels) + self.meta_clf_.fit(meta_features, y) else: - self.meta_clf_.fit(meta_features, reordered_labels, + self.meta_clf_.fit(meta_features, y, sample_weight=sample_weight) return self @@ -347,20 +304,35 @@ def predict_meta_features(self, X): Returns the meta-features for test data. """ - check_is_fitted(self, 'clfs_') - all_model_predictions = np.array([]).reshape(len(X), 0) + check_is_fitted(self, ['clfs_', 'meta_clf_']) + + per_model_preds = [] + for model in self.clfs_: if not self.use_probas: - single_model_prediction = model.predict(X) - single_model_prediction = single_model_prediction\ - .reshape(single_model_prediction.shape[0], 1) + prediction = model.predict(X)[:, np.newaxis] else: - single_model_prediction = model.predict_proba(X) - all_model_predictions = np.hstack((all_model_predictions. - astype(single_model_prediction - .dtype), - single_model_prediction)) - return all_model_predictions + prediction = model.predict_proba(X) + + per_model_preds.append(prediction) + + return np.hstack(per_model_preds) + + def _stack_first_level_features(self, X, meta_features): + if sparse.issparse(X): + stack_fn = sparse.hstack + else: + stack_fn = np.hstack + + return stack_fn((X, meta_features)) + + def _do_predict(self, X, predict_fn): + meta_features = self.predict_meta_features(X) + + if self.use_features_in_secondary: + meta_features = self._stack_first_level_features(X, meta_features) + + return predict_fn(meta_features) def predict(self, X): """ Predict target values for X. @@ -377,16 +349,9 @@ def predict(self, X): Predicted class labels. """ - check_is_fitted(self, 'clfs_') - all_model_predictions = self.predict_meta_features(X) - if not self.use_features_in_secondary: - return self.meta_clf_.predict(all_model_predictions) - elif sparse.issparse(X): - return self.meta_clf_.predict( - sparse.hstack((X, all_model_predictions))) - else: - return self.meta_clf_.predict( - np.hstack((X, all_model_predictions))) + check_is_fitted(self, ['clfs_', 'meta_clf_']) + + return self._do_predict(X, self.meta_clf_.predict) def predict_proba(self, X): """ Predict class probabilities for X. @@ -403,24 +368,6 @@ def predict_proba(self, X): Probability for each class per sample. """ - check_is_fitted(self, 'clfs_') - all_model_predictions = np.array([]).reshape(len(X), 0) - for model in self.clfs_: - if not self.use_probas: - single_model_prediction = model.predict(X) - single_model_prediction = single_model_prediction\ - .reshape(single_model_prediction.shape[0], 1) - else: - single_model_prediction = model.predict_proba(X) - all_model_predictions = np.hstack((all_model_predictions. - astype(single_model_prediction. - dtype), - single_model_prediction)) - if not self.use_features_in_secondary: - return self.meta_clf_.predict_proba(all_model_predictions) - elif sparse.issparse(X): - self.meta_clf_\ - .predict_proba(sparse.hstack((X, all_model_predictions))) - else: - return self.meta_clf_\ - .predict_proba(np.hstack((X, all_model_predictions))) + check_is_fitted(self, ['clfs_', 'meta_clf_']) + + return self._do_predict(X, self.meta_clf_.predict_proba) diff --git a/mlxtend/classifier/tests/test_stacking_cv_classifier.py b/mlxtend/classifier/tests/test_stacking_cv_classifier.py index 957f216d5..735fd890d 100644 --- a/mlxtend/classifier/tests/test_stacking_cv_classifier.py +++ b/mlxtend/classifier/tests/test_stacking_cv_classifier.py @@ -329,39 +329,6 @@ def test_verbose(): sclf.fit(X_iris, y_iris) -def test_list_of_lists(): - X_list = [i for i in X_iris] - meta = LogisticRegression(multi_class='ovr', solver='liblinear') - clf1 = RandomForestClassifier(n_estimators=10) - clf2 = GaussianNB() - sclf = StackingCVClassifier(classifiers=[clf1, clf2], - use_probas=True, - meta_classifier=meta, - shuffle=False, - verbose=0) - - try: - sclf.fit(X_list, y_iris) - except TypeError as e: - assert 'are NumPy arrays. If X and y are lists' in str(e) - - -def test_pandas(): - X_df = pd.DataFrame(X_iris) - meta = LogisticRegression(multi_class='ovr', solver='liblinear') - clf1 = RandomForestClassifier(n_estimators=10) - clf2 = GaussianNB() - sclf = StackingCVClassifier(classifiers=[clf1, clf2], - use_probas=True, - meta_classifier=meta, - shuffle=False, - verbose=0) - try: - sclf.fit(X_df, y_iris) - except KeyError as e: - assert 'are NumPy arrays. If X and y are pandas DataFrames' in str(e) - - def test_get_params(): clf1 = KNeighborsClassifier(n_neighbors=1) clf2 = RandomForestClassifier(random_state=1) @@ -493,8 +460,8 @@ def test_sparse_inputs_with_features_in_secondary(): stclf = StackingCVClassifier(classifiers=[rf, rf], meta_classifier=lr, use_features_in_secondary=True) - X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast, - test_size=0.3) + X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast, + test_size=0.3) # dense stclf.fit(X_train, y_train) @@ -505,3 +472,41 @@ def test_sparse_inputs_with_features_in_secondary(): stclf.fit(sparse.csr_matrix(X_train), y_train) assert round(stclf.score(X_train, y_train), 2) == 0.99, \ round(stclf.score(X_train, y_train), 2) + + +def test_works_with_df_if_fold_indexes_missing(): + """This is a regression test to make sure fitting will still work even if + training data has ids that cannot be indexed using the indexes from the cv + (e.g. skf) + + Some possibilities: + + Output of the folds are not neatly consecutive (i.e. [341, 345, 543, ...] + instead of [0, 1, ... n]) + + Indexes just start from some number greater than the size of the input + (see test case) + + Training data sometimes has ids that carry other information, and selection + of rows based on cv should not break. + + This is fixed in the code using `safe_indexing` + """ + + np.random.seed(123) + rf = RandomForestClassifier(n_estimators=10) + lr = LogisticRegression(multi_class='ovr', solver='liblinear') + stclf = StackingCVClassifier(classifiers=[rf, rf], + meta_classifier=lr, + use_features_in_secondary=True) + + X_modded = pd.DataFrame(X_breast, + index=np.arange(X_breast.shape[0]) + 1000) + y_modded = pd.Series(y_breast, + index=np.arange(y_breast.shape[0]) + 1000) + + X_train, X_test, y_train, y_test = train_test_split(X_modded, y_modded, + test_size=0.3) + + # dense + stclf.fit(X_train, y_train) + assert round(stclf.score(X_train, y_train), 2) == 0.99, \ + round(stclf.score(X_train, y_train), 2)