
Commit 5735e00

allow grid search for classifiers/regressors params in ensemble methods (#259)
1 parent 00981a6 commit 5735e00

16 files changed: 266 additions & 22 deletions

docs/sources/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ The CHANGELOG for the current development version is available at
 - Added `evaluate.permutation_test`, a permutation test for hypothesis testing (or A/B testing) to test whether two samples come from the same distribution; in other words, a procedure to test the null hypothesis that two groups are not significantly different (e.g., a treatment and a control group).
 - Added `'leverage'` and `'conviction'` as evaluation metrics to the `frequent_patterns.association_rules` function. [#246](https://github.com/rasbt/mlxtend/pull/246) & [#247](https://github.com/rasbt/mlxtend/pull/247)
 - Added a `loadings_` attribute to `PrincipalComponentAnalysis` to compute the factor loadings of the features on the principal components. [#251](https://github.com/rasbt/mlxtend/pull/251)
+- Allow grid search over classifiers/regressors in ensemble and stacking estimators [#259](https://github.com/rasbt/mlxtend/pull/259)

 ##### Changes

docs/sources/user_guide/classifier/EnsembleVoteClassifier.ipynb

Lines changed: 14 additions & 0 deletions

@@ -459,6 +459,20 @@
     "grid = grid.fit(iris.data, iris.target)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Note**\n",
+    "\n",
+    "The `EnsembleVoteClassifier` also enables grid search over the `clfs` argument. However, due to the current implementation of `GridSearchCV` in scikit-learn, it is not possible to search over both different classifiers and classifier parameters at the same time. For instance, while the following parameter dictionary works\n",
+    "\n",
+    "    params = {'randomforestclassifier__n_estimators': [1, 100],\n",
+    "              'clfs': [(clf1, clf1, clf1), (clf2, clf3)]}\n",
+    "\n",
+    "it will use the instance settings of `clf1`, `clf2`, and `clf3` and not overwrite them with the `'n_estimators'` settings from `'randomforestclassifier__n_estimators': [1, 100]`."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

docs/sources/user_guide/classifier/StackingCVClassifier.ipynb

Lines changed: 14 additions & 0 deletions

@@ -423,6 +423,20 @@
     "print('Accuracy: %.2f' % grid.best_score_)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Note**\n",
+    "\n",
+    "The `StackingCVClassifier` also enables grid search over the `classifiers` argument. However, due to the current implementation of `GridSearchCV` in scikit-learn, it is not possible to search over both different classifiers and classifier parameters at the same time. For instance, while the following parameter dictionary works\n",
+    "\n",
+    "    params = {'randomforestclassifier__n_estimators': [1, 100],\n",
+    "              'classifiers': [(clf1, clf1, clf1), (clf2, clf3)]}\n",
+    "\n",
+    "it will use the instance settings of `clf1`, `clf2`, and `clf3` and not overwrite them with the `'n_estimators'` settings from `'randomforestclassifier__n_estimators': [1, 100]`."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

docs/sources/user_guide/classifier/StackingClassifier.ipynb

Lines changed: 14 additions & 0 deletions

@@ -400,6 +400,20 @@
     "print('Accuracy: %.2f' % grid.best_score_)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Note**\n",
+    "\n",
+    "The `StackingClassifier` also enables grid search over the `classifiers` argument. However, due to the current implementation of `GridSearchCV` in scikit-learn, it is not possible to search over both different classifiers and classifier parameters at the same time. For instance, while the following parameter dictionary works\n",
+    "\n",
+    "    params = {'randomforestclassifier__n_estimators': [1, 100],\n",
+    "              'classifiers': [(clf1, clf1, clf1), (clf2, clf3)]}\n",
+    "\n",
+    "it will use the instance settings of `clf1`, `clf2`, and `clf3` and not overwrite them with the `'n_estimators'` settings from `'randomforestclassifier__n_estimators': [1, 100]`."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

docs/sources/user_guide/regressor/StackingCVRegressor.ipynb

Lines changed: 14 additions & 0 deletions

@@ -278,6 +278,20 @@
     "print('Accuracy: %.2f' % grid.best_score_)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Note**\n",
+    "\n",
+    "The `StackingCVRegressor` also enables grid search over the `regressors` argument. However, due to the current implementation of `GridSearchCV` in scikit-learn, it is not possible to search over both different regressors and regressor parameters at the same time. For instance, while the following parameter dictionary works\n",
+    "\n",
+    "    params = {'randomforestregressor__n_estimators': [1, 100],\n",
+    "              'regressors': [(regr1, regr1, regr1), (regr2, regr3)]}\n",
+    "\n",
+    "it will use the instance settings of `regr1`, `regr2`, and `regr3` and not overwrite them with the `'n_estimators'` settings from `'randomforestregressor__n_estimators': [1, 100]`."
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

docs/sources/user_guide/regressor/StackingRegressor.ipynb

Lines changed: 11 additions & 2 deletions

@@ -77,7 +77,9 @@
  {
   "cell_type": "code",
   "execution_count": 2,
-  "metadata": {},
+  "metadata": {
+   "collapsed": true
+  },
   "outputs": [],
   "source": [
    "from mlxtend.regressor import StackingRegressor\n",
@@ -604,7 +606,14 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "In case we are planning to use a regression algorithm multiple times, all we need to do is to add an additional number suffix in the parameter grid as shown below:"
+   "**Note**\n",
+   "\n",
+   "The `StackingRegressor` also enables grid search over the `regressors` argument. However, due to the current implementation of `GridSearchCV` in scikit-learn, it is not possible to search over both different regressors and regressor parameters at the same time. For instance, while the following parameter dictionary works\n",
+   "\n",
+   "    params = {'randomforestregressor__n_estimators': [1, 100],\n",
+   "              'regressors': [(regr1, regr1, regr1), (regr2, regr3)]}\n",
+   "\n",
+   "it will use the instance settings of `regr1`, `regr2`, and `regr3` and not overwrite them with the `'n_estimators'` settings from `'randomforestregressor__n_estimators': [1, 100]`."
   ]
  },
 {

mlxtend/classifier/ensemble_vote.py

Lines changed: 1 addition & 4 deletions

@@ -255,10 +255,7 @@ def get_params(self, deep=True):

         for key, value in six.iteritems(super(EnsembleVoteClassifier,
                                               self).get_params(deep=False)):
-            if key == 'clfs':
-                continue
-            else:
-                out['%s' % key] = value
+            out['%s' % key] = value
         return out

     def _predict(self, X):

mlxtend/classifier/stacking_classification.py

Lines changed: 1 addition & 4 deletions

@@ -141,10 +141,7 @@ def get_params(self, deep=True):

         for key, value in six.iteritems(super(StackingClassifier,
                                               self).get_params(deep=False)):
-            if key in ('classifiers', 'meta-classifier'):
-                continue
-            else:
-                out['%s' % key] = value
+            out['%s' % key] = value

         return out

mlxtend/classifier/stacking_cv_classification.py

Lines changed: 1 addition & 4 deletions

@@ -245,10 +245,7 @@ def get_params(self, deep=True):

         for key, value in six.iteritems(super(StackingCVClassifier,
                                               self).get_params(deep=False)):
-            if key in ('classifiers', 'meta-classifier'):
-                continue
-            else:
-                out['%s' % key] = value
+            out['%s' % key] = value

         return out

mlxtend/classifier/tests/test_ensemble_vote_classifier.py

Lines changed: 36 additions & 0 deletions

@@ -8,6 +8,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.naive_bayes import GaussianNB
 from sklearn.ensemble import RandomForestClassifier
+from sklearn.neighbors import KNeighborsClassifier
 import numpy as np
 from sklearn import datasets
 from sklearn.model_selection import GridSearchCV
@@ -86,3 +87,38 @@ def test_EnsembleVoteClassifier_gridsearch_enumerate_names():

     grid = GridSearchCV(estimator=eclf, param_grid=params, cv=5)
     grid = grid.fit(iris.data, iris.target)
+
+
+def test_get_params():
+    clf1 = KNeighborsClassifier(n_neighbors=1)
+    clf2 = RandomForestClassifier(random_state=1)
+    clf3 = GaussianNB()
+    eclf = EnsembleVoteClassifier(clfs=[clf1, clf2, clf3])
+
+    got = sorted(list({s.split('__')[0] for s in eclf.get_params().keys()}))
+    expect = ['clfs',
+              'gaussiannb',
+              'kneighborsclassifier',
+              'randomforestclassifier',
+              'refit',
+              'verbose',
+              'voting',
+              'weights']
+    assert got == expect, got
+
+
+def test_classifier_gridsearch():
+    clf1 = KNeighborsClassifier(n_neighbors=1)
+    clf2 = RandomForestClassifier(random_state=1)
+    clf3 = GaussianNB()
+    eclf = EnsembleVoteClassifier(clfs=[clf1])
+
+    params = {'clfs': [[clf1, clf1, clf1], [clf2, clf3]]}
+
+    grid = GridSearchCV(estimator=eclf,
+                        param_grid=params,
+                        cv=5,
+                        refit=True)
+    grid.fit(X, y)
+
+    assert len(grid.best_params_['clfs']) == 2
