From 06156df2142b6c8762f4b46e5aeacae4cf4ba0e9 Mon Sep 17 00:00:00 2001 From: ackerleytng Date: Thu, 21 Mar 2019 00:13:11 +0800 Subject: [PATCH 1/6] Refactor StackingCVClassifer and use safe_indexing --- .../classifier/stacking_cv_classification.py | 176 +++++++++--------- 1 file changed, 92 insertions(+), 84 deletions(-) diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index 5633f5dc4..9c0b9a8f1 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -19,6 +19,7 @@ from sklearn.base import clone from sklearn.externals import six from sklearn.model_selection._split import check_cv +from sklearn.utils import safe_indexing class StackingCVClassifier(BaseEstimator, ClassifierMixin, TransformerMixin): @@ -141,6 +142,59 @@ def __init__(self, classifiers, meta_classifier, self.store_train_meta_features = store_train_meta_features self.use_clones = use_clones + def _fit_fold(self, model, X, y, sample_weight, index, n_splits): + if self.verbose > 0: + print("Training and fitting fold %d of %d..." % + ((index + 1), n_splits)) + + try: + if sample_weight is None: + model.fit(X, y) + else: + model.fit(X, y, + sample_weight=sample_weight) + except TypeError as e: + + if str(e).startswith('A sparse matrix was passed,' + ' but dense' + ' data is required'): + sparse_estimator_message = ( + "\nYou are likely getting this error" + " because one of the" + " estimators" + " does not support sparse matrix input.") + else: + sparse_estimator_message = '' + + raise TypeError(str(e) + sparse_estimator_message + + '\nPlease check that X and y' + 'are NumPy arrays. If X and y are lists' + ' of lists,\ntry passing them as' + ' numpy.array(X)' + ' and numpy.array(y).') + except KeyError as e: + + raise KeyError(str(e) + '\nPlease check that X and y' + ' are NumPy arrays. If X and y are pandas' + ' DataFrames,\ntry passing them as' + ' X.values' + ' and y.values.') + + return model + + def _reorder_with_cv(self, arr, cv): + """Reorders and selects indices from arr using test indices of cv""" + + rows = [safe_indexing(arr, test_indices) + for _, test_indices in cv] + + if sparse.issparse(arr): + stack_fn = sparse.vstack + else: + stack_fn = np.concatenate + + return stack_fn(rows) + def fit(self, X, y, groups=None, sample_weight=None): """ Fit ensemble classifers and the meta-classifier. @@ -184,7 +238,8 @@ def fit(self, X, y, groups=None, sample_weight=None): final_cv.shuffle = self.shuffle skf = list(final_cv.split(X, y, groups)) - all_model_predictions = np.array([]).reshape(len(y), 0) + per_model_predictions = [] + for model in self.clfs_: if self.verbose > 0: @@ -200,92 +255,40 @@ def fit(self, X, y, groups=None, sample_weight=None): if self.verbose > 1: print(_name_estimators((model,))[0][1]) - if not self.use_probas: - single_model_prediction = np.array([]).reshape(0, 1) - else: - single_model_prediction = np.array([]).reshape(0, len(set(y))) - - for num, (train_index, test_index) in enumerate(skf): - - if self.verbose > 0: - print("Training and fitting fold %d of %d..." % - ((num + 1), final_cv.get_n_splits())) - - try: - if sample_weight is None: - model.fit(X[train_index], y[train_index]) - else: - model.fit(X[train_index], y[train_index], - sample_weight=sample_weight[train_index]) - except TypeError as e: - - if str(e).startswith('A sparse matrix was passed,' - ' but dense' - ' data is required'): - sparse_estimator_message = ( - "\nYou are likely getting this error" - " because one of the" - " estimators" - " does not support sparse matrix input.") - else: - sparse_estimator_message = '' - - raise TypeError(str(e) + sparse_estimator_message + - '\nPlease check that X and y' - 'are NumPy arrays. If X and y are lists' - ' of lists,\ntry passing them as' - ' numpy.array(X)' - ' and numpy.array(y).') - except KeyError as e: - - raise KeyError(str(e) + '\nPlease check that X and y' - ' are NumPy arrays. If X and y are pandas' - ' DataFrames,\ntry passing them as' - ' X.values' - ' and y.values.') + per_fold_predictions = [] + + for num, (train_indices, test_indices) in enumerate(skf): + + X_train = safe_indexing(X, train_indices) + y_train = safe_indexing(y, train_indices) + model = self._fit_fold( + model, + X_train, + y_train, + (safe_indexing(sample_weight, train_indices) + if sample_weight is not None else None), + num, + final_cv.get_n_splits() + ) + X_test = safe_indexing(X, test_indices) if not self.use_probas: - prediction = model.predict(X[test_index]) + prediction = model.predict(X_test) prediction = prediction.reshape(prediction.shape[0], 1) else: - prediction = model.predict_proba(X[test_index]) - single_model_prediction = np.vstack([single_model_prediction. - astype(prediction.dtype), - prediction]) + prediction = model.predict_proba(X_test) + per_fold_predictions.append(prediction) - all_model_predictions = np.hstack([all_model_predictions. - astype(single_model_prediction. - dtype), - single_model_prediction]) + per_model_predictions.append(np.vstack(per_fold_predictions)) + + all_model_predictions = np.hstack(per_model_predictions) if self.store_train_meta_features: # Store the meta features in the order of the - # original X,y arrays - reodered_indices = np.array([]).astype(y.dtype) - for train_index, test_index in skf: - reodered_indices = np.concatenate((reodered_indices, - test_index)) - self.train_meta_features_ = all_model_predictions[np.argsort( - reodered_indices)] - - # We have to shuffle the labels in the same order as we generated - # predictions during CV (we kinda shuffled them when we did - # Stratified CV). - # We also do the same with the features (we will need this only IF - # use_features_in_secondary is True) - reordered_labels = np.array([]).astype(y.dtype) - reordered_features = np.array([]).reshape((0, X.shape[1]))\ - .astype(X.dtype) - for train_index, test_index in skf: - reordered_labels = np.concatenate((reordered_labels, - y[test_index])) - - if sparse.issparse(X): - reordered_features = sparse.vstack((reordered_features, - X[test_index])) - else: - reordered_features = np.concatenate((reordered_features, - X[test_index])) + # original X, y arrays + all_test_indices = np.concatenate([i for _, i in skf]) + self.train_meta_features_ = \ + all_model_predictions[np.argsort(all_test_indices)] # Fit the base models correctly this time using ALL the training set for model in self.clfs_: @@ -297,12 +300,17 @@ def fit(self, X, y, groups=None, sample_weight=None): # Fit the secondary model if not self.use_features_in_secondary: meta_features = all_model_predictions - elif sparse.issparse(X): - meta_features = sparse.hstack((reordered_features, - all_model_predictions)) else: - meta_features = np.hstack((reordered_features, - all_model_predictions)) + if sparse.issparse(X): + stack_fn = sparse.hstack + else: + stack_fn = np.hstack + + reordered_features = self._reorder_with_cv(X, skf) + meta_features = stack_fn((reordered_features, + all_model_predictions)) + + reordered_labels = self._reorder_with_cv(y, skf) if sample_weight is None: self.meta_clf_.fit(meta_features, reordered_labels) else: From 38dccd000cdde02f89d8d0b1568d1d2a52654fc9 Mon Sep 17 00:00:00 2001 From: ackerleytng Date: Sat, 23 Mar 2019 17:17:23 +0800 Subject: [PATCH 2/6] Fix combining of probabilities by inserting 0 probability columns --- .../classifier/stacking_cv_classification.py | 72 ++++++++++++++++--- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index 9c0b9a8f1..c82a7f697 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -195,6 +195,52 @@ def _reorder_with_cv(self, arr, cv): return stack_fn(rows) + def _all_elements_equal(self, lst): + # http://stackoverflow.com/q/3844948/ + return not lst or lst.count(lst[0]) == len(lst) + + def _combine_probas(self, probas, ys): + """Takes probabilities for each fold, and combines them + + This is not a simple np.vstack of probabilities because in cases of + poor data, such as when only one instance of a certain class exists in + a training set, StratifiedKFold will result in two iterations (one + fold) having this class among the labels, and the last iteration will + not have this class in the training set. + + fit() on the first level models may result in a differing number of + columns among probas, hence causing vstack to fail. + + The workaround here is to insert a column (probability of 0) for the + missing class in that fold before combining the probabilities. + """ + + n_cols = [p.shape[1] for p in probas] + if not self._all_elements_equal(n_cols): + # Assumes that the classes in the output of predict_proba are + # sorted the way numpy sorts + y_all = np.sort(np.unique(np.concatenate(ys))) + unique_values = [np.sort(np.unique(y)) for y in ys] + missing_values = [np.setdiff1d(y_all, uv) + for uv in unique_values] + + missing_values_idxs = [] + for arr in missing_values: + missing_values_idxs.append( + [np.where(y_all == v)[0][0] + for v in arr] + ) + + new_probas = [] + for p, idxs in zip(probas, missing_values_idxs): + for i in idxs: + p = np.insert(p, i, 0, axis=1) + new_probas.append(p) + + probas = new_probas + + return np.vstack(probas) + def fit(self, X, y, groups=None, sample_weight=None): """ Fit ensemble classifers and the meta-classifier. @@ -236,9 +282,10 @@ def fit(self, X, y, groups=None, sample_weight=None): # Override shuffle parameter in case of self generated # cross-validation strategy final_cv.shuffle = self.shuffle - skf = list(final_cv.split(X, y, groups)) - per_model_predictions = [] + folds = list(final_cv.split(X, y, groups)) + + per_model_preds = [] for model in self.clfs_: @@ -255,12 +302,14 @@ def fit(self, X, y, groups=None, sample_weight=None): if self.verbose > 1: print(_name_estimators((model,))[0][1]) - per_fold_predictions = [] + per_fold_preds = [] + ys = [] - for num, (train_indices, test_indices) in enumerate(skf): + for num, (train_indices, test_indices) in enumerate(folds): X_train = safe_indexing(X, train_indices) y_train = safe_indexing(y, train_indices) + model = self._fit_fold( model, X_train, @@ -277,16 +326,19 @@ def fit(self, X, y, groups=None, sample_weight=None): prediction = prediction.reshape(prediction.shape[0], 1) else: prediction = model.predict_proba(X_test) - per_fold_predictions.append(prediction) - per_model_predictions.append(np.vstack(per_fold_predictions)) + per_fold_preds.append(prediction) + ys.append(y_train) + + all_folds_preds = self._combine_probas(per_fold_preds, ys) + per_model_preds.append(all_folds_preds) - all_model_predictions = np.hstack(per_model_predictions) + all_model_predictions = np.hstack(per_model_preds) if self.store_train_meta_features: # Store the meta features in the order of the # original X, y arrays - all_test_indices = np.concatenate([i for _, i in skf]) + all_test_indices = np.concatenate([i for _, i in folds]) self.train_meta_features_ = \ all_model_predictions[np.argsort(all_test_indices)] @@ -306,11 +358,11 @@ def fit(self, X, y, groups=None, sample_weight=None): else: stack_fn = np.hstack - reordered_features = self._reorder_with_cv(X, skf) + reordered_features = self._reorder_with_cv(X, folds) meta_features = stack_fn((reordered_features, all_model_predictions)) - reordered_labels = self._reorder_with_cv(y, skf) + reordered_labels = self._reorder_with_cv(y, folds) if sample_weight is None: self.meta_clf_.fit(meta_features, reordered_labels) else: From 09ce8e71182929b45e0857f1611ea2c8d7fd4b2d Mon Sep 17 00:00:00 2001 From: ackerleytng Date: Sun, 24 Mar 2019 08:19:24 +0800 Subject: [PATCH 3/6] Refactor prediction functions --- .../classifier/stacking_cv_classification.py | 138 +++++------------- 1 file changed, 36 insertions(+), 102 deletions(-) diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index c82a7f697..2646af39b 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -195,52 +195,6 @@ def _reorder_with_cv(self, arr, cv): return stack_fn(rows) - def _all_elements_equal(self, lst): - # http://stackoverflow.com/q/3844948/ - return not lst or lst.count(lst[0]) == len(lst) - - def _combine_probas(self, probas, ys): - """Takes probabilities for each fold, and combines them - - This is not a simple np.vstack of probabilities because in cases of - poor data, such as when only one instance of a certain class exists in - a training set, StratifiedKFold will result in two iterations (one - fold) having this class among the labels, and the last iteration will - not have this class in the training set. - - fit() on the first level models may result in a differing number of - columns among probas, hence causing vstack to fail. - - The workaround here is to insert a column (probability of 0) for the - missing class in that fold before combining the probabilities. - """ - - n_cols = [p.shape[1] for p in probas] - if not self._all_elements_equal(n_cols): - # Assumes that the classes in the output of predict_proba are - # sorted the way numpy sorts - y_all = np.sort(np.unique(np.concatenate(ys))) - unique_values = [np.sort(np.unique(y)) for y in ys] - missing_values = [np.setdiff1d(y_all, uv) - for uv in unique_values] - - missing_values_idxs = [] - for arr in missing_values: - missing_values_idxs.append( - [np.where(y_all == v)[0][0] - for v in arr] - ) - - new_probas = [] - for p, idxs in zip(probas, missing_values_idxs): - for i in idxs: - p = np.insert(p, i, 0, axis=1) - new_probas.append(p) - - probas = new_probas - - return np.vstack(probas) - def fit(self, X, y, groups=None, sample_weight=None): """ Fit ensemble classifers and the meta-classifier. @@ -303,7 +257,6 @@ def fit(self, X, y, groups=None, sample_weight=None): print(_name_estimators((model,))[0][1]) per_fold_preds = [] - ys = [] for num, (train_indices, test_indices) in enumerate(folds): @@ -328,19 +281,18 @@ def fit(self, X, y, groups=None, sample_weight=None): prediction = model.predict_proba(X_test) per_fold_preds.append(prediction) - ys.append(y_train) - all_folds_preds = self._combine_probas(per_fold_preds, ys) + all_folds_preds = np.vstack(per_fold_preds) per_model_preds.append(all_folds_preds) - all_model_predictions = np.hstack(per_model_preds) + meta_features = np.hstack(per_model_preds) if self.store_train_meta_features: # Store the meta features in the order of the # original X, y arrays all_test_indices = np.concatenate([i for _, i in folds]) self.train_meta_features_ = \ - all_model_predictions[np.argsort(all_test_indices)] + meta_features[np.argsort(all_test_indices)] # Fit the base models correctly this time using ALL the training set for model in self.clfs_: @@ -350,17 +302,12 @@ def fit(self, X, y, groups=None, sample_weight=None): model.fit(X, y, sample_weight=sample_weight) # Fit the secondary model - if not self.use_features_in_secondary: - meta_features = all_model_predictions - else: - if sparse.issparse(X): - stack_fn = sparse.hstack - else: - stack_fn = np.hstack - + if self.use_features_in_secondary: reordered_features = self._reorder_with_cv(X, folds) - meta_features = stack_fn((reordered_features, - all_model_predictions)) + meta_features = self._stack_first_level_features( + reordered_features, + meta_features + ) reordered_labels = self._reorder_with_cv(y, folds) if sample_weight is None: @@ -408,19 +355,35 @@ def predict_meta_features(self, X): """ check_is_fitted(self, 'clfs_') - all_model_predictions = np.array([]).reshape(len(X), 0) + + per_model_preds = [] + for model in self.clfs_: if not self.use_probas: - single_model_prediction = model.predict(X) - single_model_prediction = single_model_prediction\ - .reshape(single_model_prediction.shape[0], 1) + prediction = model.predict(X) + prediction = prediction.reshape(prediction.shape[0], 1) else: - single_model_prediction = model.predict_proba(X) - all_model_predictions = np.hstack((all_model_predictions. - astype(single_model_prediction - .dtype), - single_model_prediction)) - return all_model_predictions + prediction = model.predict_proba(X) + + per_model_preds.append(prediction) + + return np.hstack(per_model_preds) + + def _stack_first_level_features(self, X, meta_features): + if sparse.issparse(X): + stack_fn = sparse.hstack + else: + stack_fn = np.hstack + + return stack_fn((X, meta_features)) + + def _do_predict(self, X, predict_fn): + meta_features = self.predict_meta_features(X) + + if self.use_features_in_secondary: + self._stack_first_level_features(X, meta_features) + + return predict_fn(meta_features) def predict(self, X): """ Predict target values for X. @@ -437,16 +400,7 @@ def predict(self, X): Predicted class labels. """ - check_is_fitted(self, 'clfs_') - all_model_predictions = self.predict_meta_features(X) - if not self.use_features_in_secondary: - return self.meta_clf_.predict(all_model_predictions) - elif sparse.issparse(X): - return self.meta_clf_.predict( - sparse.hstack((X, all_model_predictions))) - else: - return self.meta_clf_.predict( - np.hstack((X, all_model_predictions))) + return self._do_predict(X, self.meta_clf_.predict) def predict_proba(self, X): """ Predict class probabilities for X. @@ -463,24 +417,4 @@ def predict_proba(self, X): Probability for each class per sample. """ - check_is_fitted(self, 'clfs_') - all_model_predictions = np.array([]).reshape(len(X), 0) - for model in self.clfs_: - if not self.use_probas: - single_model_prediction = model.predict(X) - single_model_prediction = single_model_prediction\ - .reshape(single_model_prediction.shape[0], 1) - else: - single_model_prediction = model.predict_proba(X) - all_model_predictions = np.hstack((all_model_predictions. - astype(single_model_prediction. - dtype), - single_model_prediction)) - if not self.use_features_in_secondary: - return self.meta_clf_.predict_proba(all_model_predictions) - elif sparse.issparse(X): - self.meta_clf_\ - .predict_proba(sparse.hstack((X, all_model_predictions))) - else: - return self.meta_clf_\ - .predict_proba(np.hstack((X, all_model_predictions))) + return self._do_predict(X, self.meta_clf_.predict_proba) From 755f6cb63eb4ca1fc3940a592b090ea4a5ebf135 Mon Sep 17 00:00:00 2001 From: ackerleytng Date: Sun, 24 Mar 2019 08:44:20 +0800 Subject: [PATCH 4/6] Fix after running test cases --- mlxtend/classifier/stacking_cv_classification.py | 8 ++++++-- mlxtend/classifier/tests/test_stacking_cv_classifier.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index 2646af39b..10dd4914c 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -354,7 +354,7 @@ def predict_meta_features(self, X): Returns the meta-features for test data. """ - check_is_fitted(self, 'clfs_') + check_is_fitted(self, ['clfs_', 'meta_clf_']) per_model_preds = [] @@ -381,7 +381,7 @@ def _do_predict(self, X, predict_fn): meta_features = self.predict_meta_features(X) if self.use_features_in_secondary: - self._stack_first_level_features(X, meta_features) + meta_features = self._stack_first_level_features(X, meta_features) return predict_fn(meta_features) @@ -400,6 +400,8 @@ def predict(self, X): Predicted class labels. """ + check_is_fitted(self, ['clfs_', 'meta_clf_']) + return self._do_predict(X, self.meta_clf_.predict) def predict_proba(self, X): @@ -417,4 +419,6 @@ def predict_proba(self, X): Probability for each class per sample. """ + check_is_fitted(self, ['clfs_', 'meta_clf_']) + return self._do_predict(X, self.meta_clf_.predict_proba) diff --git a/mlxtend/classifier/tests/test_stacking_cv_classifier.py b/mlxtend/classifier/tests/test_stacking_cv_classifier.py index 957f216d5..45181bacc 100644 --- a/mlxtend/classifier/tests/test_stacking_cv_classifier.py +++ b/mlxtend/classifier/tests/test_stacking_cv_classifier.py @@ -493,8 +493,8 @@ def test_sparse_inputs_with_features_in_secondary(): stclf = StackingCVClassifier(classifiers=[rf, rf], meta_classifier=lr, use_features_in_secondary=True) - X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast, - test_size=0.3) + X_train, X_test, y_train, y_test = train_test_split(X_breast, y_breast, + test_size=0.3) # dense stclf.fit(X_train, y_train) From 0aee89af2bf85db1d25ed486f80789ac52e13afa Mon Sep 17 00:00:00 2001 From: ackerleytng Date: Sun, 24 Mar 2019 10:46:41 +0800 Subject: [PATCH 5/6] Add regression test --- .../tests/test_stacking_cv_classifier.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/mlxtend/classifier/tests/test_stacking_cv_classifier.py b/mlxtend/classifier/tests/test_stacking_cv_classifier.py index 45181bacc..d75547927 100644 --- a/mlxtend/classifier/tests/test_stacking_cv_classifier.py +++ b/mlxtend/classifier/tests/test_stacking_cv_classifier.py @@ -505,3 +505,41 @@ def test_sparse_inputs_with_features_in_secondary(): stclf.fit(sparse.csr_matrix(X_train), y_train) assert round(stclf.score(X_train, y_train), 2) == 0.99, \ round(stclf.score(X_train, y_train), 2) + + +def test_works_with_df_if_fold_indexes_missing(): + """This is a regression test to make sure fitting will still work even if + training data has ids that cannot be indexed using the indexes from the cv + (e.g. skf) + + Some possibilities: + + Output of the folds are not neatly consecutive (i.e. [341, 345, 543, ...] + instead of [0, 1, ... n]) + + Indexes just start from some number greater than the size of the input + (see test case) + + Training data sometimes has ids that carry other information, and selection + of rows based on cv should not break. + + This is fixed in the code using `safe_indexing` + """ + + np.random.seed(123) + rf = RandomForestClassifier(n_estimators=10) + lr = LogisticRegression(multi_class='ovr', solver='liblinear') + stclf = StackingCVClassifier(classifiers=[rf, rf], + meta_classifier=lr, + use_features_in_secondary=True) + + X_modded = pd.DataFrame(X_breast, + index=np.arange(X_breast.shape[0]) + 1000) + y_modded = pd.Series(y_breast, + index=np.arange(y_breast.shape[0]) + 1000) + + X_train, X_test, y_train, y_test = train_test_split(X_modded, y_modded, + test_size=0.3) + + # dense + stclf.fit(X_train, y_train) + assert round(stclf.score(X_train, y_train), 2) == 0.99, \ + round(stclf.score(X_train, y_train), 2) From 724a0e6622ba74600c3044f5ecf7526c38bcc7cf Mon Sep 17 00:00:00 2001 From: ackerleytng Date: Sun, 31 Mar 2019 22:00:36 +0800 Subject: [PATCH 6/6] Write fit to avoid use of hstack or vstack Also remove need to reorder labels when building meta features --- .../classifier/stacking_cv_classification.py | 121 +++++------------- .../tests/test_stacking_cv_classifier.py | 33 ----- 2 files changed, 35 insertions(+), 119 deletions(-) diff --git a/mlxtend/classifier/stacking_cv_classification.py b/mlxtend/classifier/stacking_cv_classification.py index 10dd4914c..2eb329492 100644 --- a/mlxtend/classifier/stacking_cv_classification.py +++ b/mlxtend/classifier/stacking_cv_classification.py @@ -142,59 +142,6 @@ def __init__(self, classifiers, meta_classifier, self.store_train_meta_features = store_train_meta_features self.use_clones = use_clones - def _fit_fold(self, model, X, y, sample_weight, index, n_splits): - if self.verbose > 0: - print("Training and fitting fold %d of %d..." % - ((index + 1), n_splits)) - - try: - if sample_weight is None: - model.fit(X, y) - else: - model.fit(X, y, - sample_weight=sample_weight) - except TypeError as e: - - if str(e).startswith('A sparse matrix was passed,' - ' but dense' - ' data is required'): - sparse_estimator_message = ( - "\nYou are likely getting this error" - " because one of the" - " estimators" - " does not support sparse matrix input.") - else: - sparse_estimator_message = '' - - raise TypeError(str(e) + sparse_estimator_message + - '\nPlease check that X and y' - 'are NumPy arrays. If X and y are lists' - ' of lists,\ntry passing them as' - ' numpy.array(X)' - ' and numpy.array(y).') - except KeyError as e: - - raise KeyError(str(e) + '\nPlease check that X and y' - ' are NumPy arrays. If X and y are pandas' - ' DataFrames,\ntry passing them as' - ' X.values' - ' and y.values.') - - return model - - def _reorder_with_cv(self, arr, cv): - """Reorders and selects indices from arr using test indices of cv""" - - rows = [safe_indexing(arr, test_indices) - for _, test_indices in cv] - - if sparse.issparse(arr): - stack_fn = sparse.vstack - else: - stack_fn = np.concatenate - - return stack_fn(rows) - def fit(self, X, y, groups=None, sample_weight=None): """ Fit ensemble classifers and the meta-classifier. @@ -239,9 +186,16 @@ def fit(self, X, y, groups=None, sample_weight=None): folds = list(final_cv.split(X, y, groups)) - per_model_preds = [] + # Handle the case of X being a list of lists + # by converting X into a numpy array + if isinstance(X, list): + X = np.array(X) - for model in self.clfs_: + meta_features = None + n_folds = final_cv.get_n_splits() + n_models = len(self.clfs_) + + for n, model in enumerate(self.clfs_): if self.verbose > 0: i = self.clfs_.index(model) + 1 @@ -256,43 +210,41 @@ def fit(self, X, y, groups=None, sample_weight=None): if self.verbose > 1: print(_name_estimators((model,))[0][1]) - per_fold_preds = [] - for num, (train_indices, test_indices) in enumerate(folds): X_train = safe_indexing(X, train_indices) y_train = safe_indexing(y, train_indices) - model = self._fit_fold( - model, - X_train, - y_train, - (safe_indexing(sample_weight, train_indices) - if sample_weight is not None else None), - num, - final_cv.get_n_splits() - ) + if self.verbose > 0: + print("Training and fitting fold %d of %d..." % + ((num + 1), n_folds)) + + if sample_weight is None: + model.fit(X_train, y_train) + else: + w = safe_indexing(sample_weight, train_indices) + model.fit(X_train, y_train, sample_weight=w) X_test = safe_indexing(X, test_indices) if not self.use_probas: - prediction = model.predict(X_test) - prediction = prediction.reshape(prediction.shape[0], 1) + prediction = model.predict(X_test)[:, np.newaxis] else: prediction = model.predict_proba(X_test) - per_fold_preds.append(prediction) - - all_folds_preds = np.vstack(per_fold_preds) - per_model_preds.append(all_folds_preds) - - meta_features = np.hstack(per_model_preds) + if meta_features is None: + # First run, use prediction to get the number of classes + n_classes = prediction.shape[1] + meta_features_shape = (X.shape[0], n_classes * n_models) + meta_features = np.empty(shape=meta_features_shape) + meta_features[np.array(test_indices)[:, np.newaxis], + np.arange(n_classes)] = prediction + else: + row_idx = np.array(test_indices)[:, np.newaxis] + col_idx = np.arange(n_classes) + n * n_classes + meta_features[row_idx, col_idx] = prediction if self.store_train_meta_features: - # Store the meta features in the order of the - # original X, y arrays - all_test_indices = np.concatenate([i for _, i in folds]) - self.train_meta_features_ = \ - meta_features[np.argsort(all_test_indices)] + self.train_meta_features_ = meta_features # Fit the base models correctly this time using ALL the training set for model in self.clfs_: @@ -303,17 +255,15 @@ def fit(self, X, y, groups=None, sample_weight=None): # Fit the secondary model if self.use_features_in_secondary: - reordered_features = self._reorder_with_cv(X, folds) meta_features = self._stack_first_level_features( - reordered_features, + X, meta_features ) - reordered_labels = self._reorder_with_cv(y, folds) if sample_weight is None: - self.meta_clf_.fit(meta_features, reordered_labels) + self.meta_clf_.fit(meta_features, y) else: - self.meta_clf_.fit(meta_features, reordered_labels, + self.meta_clf_.fit(meta_features, y, sample_weight=sample_weight) return self @@ -360,8 +310,7 @@ def predict_meta_features(self, X): for model in self.clfs_: if not self.use_probas: - prediction = model.predict(X) - prediction = prediction.reshape(prediction.shape[0], 1) + prediction = model.predict(X)[:, np.newaxis] else: prediction = model.predict_proba(X) diff --git a/mlxtend/classifier/tests/test_stacking_cv_classifier.py b/mlxtend/classifier/tests/test_stacking_cv_classifier.py index d75547927..735fd890d 100644 --- a/mlxtend/classifier/tests/test_stacking_cv_classifier.py +++ b/mlxtend/classifier/tests/test_stacking_cv_classifier.py @@ -329,39 +329,6 @@ def test_verbose(): sclf.fit(X_iris, y_iris) -def test_list_of_lists(): - X_list = [i for i in X_iris] - meta = LogisticRegression(multi_class='ovr', solver='liblinear') - clf1 = RandomForestClassifier(n_estimators=10) - clf2 = GaussianNB() - sclf = StackingCVClassifier(classifiers=[clf1, clf2], - use_probas=True, - meta_classifier=meta, - shuffle=False, - verbose=0) - - try: - sclf.fit(X_list, y_iris) - except TypeError as e: - assert 'are NumPy arrays. If X and y are lists' in str(e) - - -def test_pandas(): - X_df = pd.DataFrame(X_iris) - meta = LogisticRegression(multi_class='ovr', solver='liblinear') - clf1 = RandomForestClassifier(n_estimators=10) - clf2 = GaussianNB() - sclf = StackingCVClassifier(classifiers=[clf1, clf2], - use_probas=True, - meta_classifier=meta, - shuffle=False, - verbose=0) - try: - sclf.fit(X_df, y_iris) - except KeyError as e: - assert 'are NumPy arrays. If X and y are pandas DataFrames' in str(e) - - def test_get_params(): clf1 = KNeighborsClassifier(n_neighbors=1) clf2 = RandomForestClassifier(random_state=1)