
Commit 9c8529a

Merge pull request #354 from zdgriffith/efs-fit-params
Adds fit_params support for ExhaustiveFeatureSelector
2 parents: c019c87 + d1ef89f

File tree: 6 files changed, +84 -20 lines changed


docs/sources/CHANGELOG.md
Lines changed: 4 additions & 1 deletion

@@ -18,7 +18,10 @@ The CHANGELOG for the current development version is available at
 
 ##### New Features
 
-- The fit method of the SequentialFeatureSelector now optionally accepts **fit_params for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith)
+- The fit method of the ExhaustiveFeatureSelector now optionally accepts
+**fit_params for the estimator that is used for the feature selection. ([#354](https://github.com/rasbt/mlxtend/pull/354) by Zach Griffith)
+- The fit method of the SequentialFeatureSelector now optionally accepts
+**fit_params for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith)
 
 
 
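With this change, both selectors share the same calling convention: any keyword argument given to fit is forwarded to the underlying estimator. A minimal sketch of the new usage, assuming an estimator whose fit method accepts sample_weight (the data and weights below are illustrative, not part of the commit):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from mlxtend.feature_selection import ExhaustiveFeatureSelector as EFS

iris = load_iris()
X, y = iris.data, iris.target

# Keyword arguments passed to fit() are forwarded to the estimator's own
# fit method, here a per-sample weight vector for RandomForestClassifier.
weights = np.ones(X.shape[0])

efs = EFS(RandomForestClassifier(n_estimators=10, random_state=0),
          min_features=2,
          max_features=3,
          scoring='accuracy',
          cv=3)
efs = efs.fit(X, y, sample_weight=weights)
print(efs.best_idx_, efs.best_score_)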

docs/sources/user_guide/feature_selection/ExhaustiveFeatureSelector.ipynb
Lines changed: 15 additions & 3 deletions

@@ -1666,7 +1666,7 @@
     "\n",
     "<hr>\n",
     "\n",
-    "*fit(X, y)*\n",
+    "*fit(X, y, **fit_params)*\n",
     "\n",
     "Perform feature selection and learn model from training data.\n",
     "\n",
@@ -1681,14 +1681,18 @@
     "\n",
     "    Target values.\n",
     "\n",
+    "- `fit_params` : dict of string -> object, optional\n",
+    "\n",
+    "    Parameters to pass to the fit method of the classifier.\n",
+    "\n",
     "**Returns**\n",
     "\n",
     "- `self` : object\n",
     "\n",
     "\n",
     "<hr>\n",
     "\n",
-    "*fit_transform(X, y)*\n",
+    "*fit_transform(X, y, **fit_params)*\n",
     "\n",
     "Fit to training data and return the best selected features from X.\n",
     "\n",
@@ -1699,6 +1703,14 @@
     "    Training vectors, where n_samples is the number of samples and\n",
     "    n_features is the number of features.\n",
     "\n",
+    "- `y` : array-like, shape = [n_samples]\n",
+    "\n",
+    "    Target values.\n",
+    "\n",
+    "- `fit_params` : dict of string -> object, optional\n",
+    "\n",
+    "    Parameters to pass to the fit method of the classifier.\n",
+    "\n",
     "**Returns**\n",
     "\n",
     "Feature subset of X, shape={n_samples, k_features}\n",
@@ -1815,7 +1827,7 @@
     "name": "python",
     "nbconvert_exporter": "python",
     "pygments_lexer": "ipython3",
-    "version": "3.6.1"
+    "version": "3.6.3"
   }
 },
 "nbformat": 4,

docs/sources/user_guide/feature_selection/SequentialFeatureSelector.ipynb
Lines changed: 4 additions & 0 deletions

@@ -1592,6 +1592,10 @@
     "    Training vectors, where n_samples is the number of samples and\n",
     "    n_features is the number of features.\n",
     "\n",
+    "- `y` : array-like, shape = [n_samples]\n",
+    "\n",
+    "    Target values.\n",
+    "\n",
     "- `fit_params` : dict of string -> object, optional\n",
     "\n",
     "    Parameters to pass to the fit method of the classifier.\n",

mlxtend/feature_selection/exhaustive_feature_selector.py
Lines changed: 24 additions & 16 deletions

@@ -25,16 +25,17 @@
 from sklearn.externals.joblib import Parallel, delayed
 
 
-def _calc_score(selector, X, y, indices):
+def _calc_score(selector, X, y, indices, **fit_params):
     if selector.cv:
         scores = cross_val_score(selector.est_,
                                  X[:, indices], y,
                                  cv=selector.cv,
                                  scoring=selector.scorer,
                                  n_jobs=1,
-                                 pre_dispatch=selector.pre_dispatch)
+                                 pre_dispatch=selector.pre_dispatch,
+                                 fit_params=fit_params)
     else:
-        selector.est_.fit(X[:, indices], y)
+        selector.est_.fit(X[:, indices], y, **fit_params)
         scores = np.array([selector.scorer(selector.est_, X[:, indices], y)])
     return indices, scores
 
@@ -127,7 +128,7 @@ def __init__(self, estimator, min_features=1, max_features=1,
         self.est_ = self.estimator
         self.fitted = False
 
-    def fit(self, X, y):
+    def fit(self, X, y, **fit_params):
         """Perform feature selection and learn model from training data.
 
         Parameters
@@ -137,6 +138,8 @@ def fit(self, X, y):
             n_features is the number of features.
         y : array-like, shape = [n_samples]
             Target values.
+        fit_params : dict of string -> object, optional
+            Parameters to pass to the fit method of the classifier.
 
         Returns
         -------
@@ -160,41 +163,42 @@ def fit(self, X, y):
             raise AttributeError('min_features must be <= max_features')
 
         candidates = chain(*((combinations(range(X.shape[1]), r=i))
-                           for i in range(self.min_features,
-                                          self.max_features + 1)))
+                             for i in range(self.min_features,
+                                            self.max_features + 1)))
 
         self.subsets_ = {}
-
+
         def ncr(n, r):
             """Return the number of combinations of length r from n items.
-
+
             Parameters
             ----------
             n : {integer}
                 Total number of items
             r : {integer}
                 Number of items to select from n
-
+
             Returns
             -------
             Number of combinations, integer
-
+
             """
-
+
             r = min(r, n-r)
             if r == 0:
                 return 1
             numer = reduce(op.mul, range(n, n-r, -1))
             denom = reduce(op.mul, range(1, r+1))
             return numer//denom
-
+
         all_comb = np.sum([ncr(n=X.shape[1], r=i)
                            for i in range(self.min_features,
                                           self.max_features + 1)])
-
+
         n_jobs = min(self.n_jobs, all_comb)
         parallel = Parallel(n_jobs=n_jobs, pre_dispatch=self.pre_dispatch)
-        work = enumerate(parallel(delayed(_calc_score)(self, X, y, c)
+        work = enumerate(parallel(delayed(_calc_score)
+                                  (self, X, y, c, **fit_params)
                                   for c in candidates))
 
         for iteration, (c, cv_scores) in work:
@@ -239,21 +243,25 @@ def transform(self, X):
         self._check_fitted()
         return X[:, self.best_idx_]
 
-    def fit_transform(self, X, y):
+    def fit_transform(self, X, y, **fit_params):
         """Fit to training data and return the best selected features from X.
 
         Parameters
         ----------
         X : {array-like, sparse matrix}, shape = [n_samples, n_features]
             Training vectors, where n_samples is the number of samples and
             n_features is the number of features.
+        y : array-like, shape = [n_samples]
+            Target values.
+        fit_params : dict of string -> object, optional
+            Parameters to pass to the fit method of the classifier.
 
         Returns
         -------
         Feature subset of X, shape={n_samples, k_features}
 
         """
-        self.fit(X, y)
+        self.fit(X, y, **fit_params)
         return self.transform(X)
 
     def get_metric_dict(self, confidence_interval=0.95):
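The core of the patch is _calc_score: with cross-validation enabled, the collected keyword arguments travel to scikit-learn as the fit_params dict of cross_val_score; without it, they are splatted directly into the estimator's fit call. A self-contained sketch of that dispatch, simplified from the patched function (the names here mirror, but are not, the mlxtend internals; note that recent scikit-learn releases renamed cross_val_score's fit_params argument to params):

import numpy as np
from sklearn.model_selection import cross_val_score

def calc_score_sketch(estimator, scorer, cv, X, y, indices, **fit_params):
    """Simplified stand-in for the patched _calc_score.

    `scorer` is a scorer callable, e.g. sklearn.metrics.get_scorer('accuracy').
    """
    if cv:
        # cross_val_score accepts fit parameters as a dict and passes them
        # to estimator.fit() inside every CV split.
        scores = cross_val_score(estimator, X[:, indices], y,
                                 cv=cv, scoring=scorer,
                                 fit_params=fit_params)
    else:
        # Without CV, the keyword arguments go straight to fit().
        estimator.fit(X[:, indices], y, **fit_params)
        scores = np.array([scorer(estimator, X[:, indices], y)])
    return indices, scores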

mlxtend/feature_selection/sequential_feature_selector.py
Lines changed: 2 additions & 0 deletions

@@ -481,6 +481,8 @@ def fit_transform(self, X, y, **fit_params):
         X : {array-like, sparse matrix}, shape = [n_samples, n_features]
             Training vectors, where n_samples is the number of samples and
             n_features is the number of features.
+        y : array-like, shape = [n_samples]
+            Target values.
         fit_params : dict of string -> object, optional
             Parameters to pass to the fit method of the classifier.
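Because the SequentialFeatureSelector gained the same behavior in #350, the pattern carries over directly. A brief sketch, again assuming an estimator that understands sample_weight (the data and weights are illustrative):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from mlxtend.feature_selection import SequentialFeatureSelector as SFS

iris = load_iris()
X, y = iris.data, iris.target
sfs = SFS(LogisticRegression(), k_features=2, forward=True,
          scoring='accuracy', cv=3)

# As with the exhaustive selector, extra keyword arguments are forwarded
# to the estimator's fit method during the selection.
X_sel = sfs.fit_transform(X, y, sample_weight=np.ones(X.shape[0]))
print(sfs.k_feature_idx_, X_sel.shape)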
486488

mlxtend/feature_selection/tests/test_exhaustive_feature_selector.py
Lines changed: 35 additions & 0 deletions

@@ -8,6 +8,7 @@
 import numpy as np
 from numpy.testing import assert_almost_equal
 from mlxtend.feature_selection import ExhaustiveFeatureSelector as EFS
+from sklearn.ensemble import RandomForestClassifier
 from sklearn.neighbors import KNeighborsClassifier
 from mlxtend.classifier import SoftmaxRegression
 from sklearn.datasets import load_iris
@@ -164,6 +165,40 @@ def test_knn_cv3():
     assert round(efs1.best_score_, 4) == 0.9728
 
 
+def test_fit_params():
+    iris = load_iris()
+    X = iris.data
+    y = iris.target
+    sample_weight = np.ones(X.shape[0])
+    forest = RandomForestClassifier(n_estimators=100, random_state=123)
+    efs1 = EFS(forest,
+               min_features=3,
+               max_features=3,
+               scoring='accuracy',
+               cv=4,
+               print_progress=False)
+    efs1 = efs1.fit(X, y, sample_weight=sample_weight)
+    expect = {0: {'feature_idx': (0, 1, 2),
+                  'cv_scores': np.array([0.94871795, 0.92307692,
+                                         0.91666667, 0.97222222]),
+                  'avg_score': 0.9401709401709402},
+              1: {'feature_idx': (0, 1, 3),
+                  'cv_scores': np.array([0.92307692, 0.92307692,
+                                         0.88888889, 1.]),
+                  'avg_score': 0.9337606837606838},
+              2: {'feature_idx': (0, 2, 3),
+                  'cv_scores': np.array([0.97435897, 0.94871795,
+                                         0.94444444, 0.97222222]),
+                  'avg_score': 0.9599358974358974},
+              3: {'feature_idx': (1, 2, 3),
+                  'cv_scores': np.array([0.97435897, 0.94871795,
+                                         0.91666667, 1.]),
+                  'avg_score': 0.9599358974358974}}
+    dict_compare_utility(d1=expect, d2=efs1.subsets_)
+    assert efs1.best_idx_ == (0, 2, 3)
+    assert round(efs1.best_score_, 4) == 0.9599
+
+
 def test_regression():
     boston = load_boston()
     X, y = boston.data[:, [1, 2, 6, 8, 12]], boston.target
