Skip to content

Commit 115278b

Browse files
harenbergsdrasbt
authored andcommitted
fix fpmax issue (#570) with fptrees that contain no nodes (#573)
* fix fpmax issue (#570) with fptrees that contain no nodes * Add additional unit test for pattern mining. Also refactored tests. * update changelog * bumb version to 0.18.0dev0 * add unit test for min_support=0.
1 parent ac0f0c1 commit 115278b

File tree

9 files changed

+256
-51
lines changed

9 files changed

+256
-51
lines changed

docs/sources/CHANGELOG.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@ The CHANGELOG for the current development version is available at
77

88
---
99

10+
### Version 0.18.0 (TBD)
11+
12+
##### Downloads
13+
14+
- [Source code (zip)](https://github.com/rasbt/mlxtend/archive/v0.18.0.zip)
15+
16+
- [Source code (tar.gz)](https://github.com/rasbt/mlxtend/archive/v0.18.0.tar.gz)
17+
18+
##### New Features
19+
20+
- -
21+
22+
##### Changes
23+
24+
- -
25+
26+
##### Bug Fixes
27+
28+
- Behavior of `fpgrowth` and `apriori` consistent for edgecases such as `min_support=0`. ([#573](https://github.com/rasbt/mlxtend/pull/550) via [Steve Harenberg](https://github.com/harenbergsd))
29+
- `fpmax` returns an empty data frame now instead of raising an error if the frequent itemset set is empty. ([#573](https://github.com/rasbt/mlxtend/pull/550) via [Steve Harenberg](https://github.com/harenbergsd))
30+
31+
32+
33+
34+
1035
### Version 0.17.0 (07/19/2019)
1136

1237
##### Downloads

mlxtend/frequent_patterns/apriori.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def _support(_x, _n_rows, _is_sparse):
143143
out = (np.sum(_x, axis=0) / _n_rows)
144144
return np.array(out).reshape(-1)
145145

146+
if min_support <= 0.:
147+
raise ValueError('`min_support` must be a positive '
148+
'number within the interval `(0, 1]`. '
149+
'Got %s.' % min_support)
150+
146151
idxs = np.where((df.values != 1) & (df.values != 0))
147152
if len(idxs[0]) > 0:
148153
val = df.values[idxs[0][0], idxs[1][0]]

mlxtend/frequent_patterns/fpcommon.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def insert_itemset(self, itemset, count=1):
123123
count : int
124124
The number of occurrences of the itemset.
125125
"""
126+
self.root.count += count
127+
126128
if len(itemset) == 0:
127129
return
128130

@@ -162,7 +164,7 @@ def print_status(self, count, colnames):
162164

163165

164166
class FPNode(object):
165-
def __init__(self, item, count=1, parent=None):
167+
def __init__(self, item, count=0, parent=None):
166168
self.item = item
167169
self.count = count
168170
self.parent = parent

mlxtend/frequent_patterns/fpgrowth.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def fpgrowth(df, min_support=0.5, use_colnames=False, max_len=None, verbose=0):
6060
"""
6161
fpc.valid_input_check(df)
6262

63+
if min_support <= 0.:
64+
raise ValueError('`min_support` must be a positive '
65+
'number within the interval `(0, 1]`. '
66+
'Got %s.' % min_support)
67+
6368
colname_map = None
6469
if use_colnames:
6570
colname_map = {idx: item for idx, item in enumerate(df.columns)}

mlxtend/frequent_patterns/fpmax.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ def fpmax(df, min_support=0.5, use_colnames=False, max_len=None, verbose=0):
6161
"""
6262
fpc.valid_input_check(df)
6363

64+
if min_support <= 0.:
65+
raise ValueError('`min_support` must be a positive '
66+
'number within the interval `(0, 1]`. '
67+
'Got %s.' % min_support)
68+
6469
colname_map = None
6570
if use_colnames:
6671
colname_map = {idx: item for idx, item in enumerate(df.columns)}
@@ -78,14 +83,16 @@ def fpmax_step(tree, minsup, mfit, colnames, max_len, verbose):
7883
count = 0
7984
items = list(tree.nodes.keys())
8085
largest_set = sorted(tree.cond_items+items, key=mfit.rank.get)
86+
if len(largest_set) == 0:
87+
return
8188
if tree.is_path():
8289
if not mfit.contains(largest_set):
8390
count += 1
8491
largest_set.reverse()
8592
mfit.cache = largest_set
8693
mfit.insert_itemset(largest_set)
8794
if max_len is None or len(largest_set) <= max_len:
88-
support = min([tree.nodes[i][0].count for i in items])
95+
support = tree.root.count
8996
yield support, largest_set
9097

9198
if verbose:

mlxtend/frequent_patterns/tests/test_apriori.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,36 @@
66

77
import unittest
88
import numpy as np
9-
from mlxtend.frequent_patterns.tests.test_fpbase import FPTestAll
9+
from test_fpbase import FPTestEdgeCases, FPTestErrors, \
10+
FPTestEx1All, FPTestEx2All, FPTestEx3All
1011
from mlxtend.frequent_patterns import apriori
1112

1213

1314
def apriori_wrapper_low_memory(*args, **kwargs):
1415
return apriori(*args, **kwargs, low_memory=True)
1516

1617

17-
class TestApriori(unittest.TestCase, FPTestAll):
18+
class TestEdgeCases(unittest.TestCase, FPTestEdgeCases):
1819
def setUp(self):
19-
FPTestAll.setUp(self, apriori)
20+
FPTestEdgeCases.setUp(self, apriori)
2021

2122

22-
class TestAprioriLowMemory(unittest.TestCase, FPTestAll):
23+
class TestErrors(unittest.TestCase, FPTestErrors):
2324
def setUp(self):
24-
FPTestAll.setUp(self, apriori_wrapper_low_memory)
25+
FPTestErrors.setUp(self, apriori)
2526

2627

27-
class TestAprioriBinaryInput(unittest.TestCase, FPTestAll):
28+
class TestApriori(unittest.TestCase, FPTestEx1All):
29+
def setUp(self):
30+
FPTestEx1All.setUp(self, apriori)
31+
32+
33+
class TestAprioriLowMemory(unittest.TestCase, FPTestEx1All):
34+
def setUp(self):
35+
FPTestEx1All.setUp(self, apriori_wrapper_low_memory)
36+
37+
38+
class TestAprioriBoolInput(unittest.TestCase, FPTestEx1All):
2839
def setUp(self):
2940
one_ary = np.array(
3041
[[False, False, False, True, False, True, True, True, True,
@@ -37,4 +48,14 @@ def setUp(self):
3748
True, True],
3849
[False, True, False, True, True, True, False, False, True,
3950
False, False]])
40-
FPTestAll.setUp(self, apriori, one_ary=one_ary)
51+
FPTestEx1All.setUp(self, apriori, one_ary=one_ary)
52+
53+
54+
class TestEx2(unittest.TestCase, FPTestEx2All):
55+
def setUp(self):
56+
FPTestEx2All.setUp(self, apriori)
57+
58+
59+
class TestEx3(unittest.TestCase, FPTestEx3All):
60+
def setUp(self):
61+
FPTestEx3All.setUp(self, apriori)

mlxtend/frequent_patterns/tests/test_fpbase.py

Lines changed: 100 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
from numpy.testing import assert_array_equal
99
from mlxtend.utils import assert_raises
10+
from mlxtend.preprocessing import TransactionEncoder
1011
import pandas as pd
1112
import sys
1213
from contextlib import contextmanager
@@ -24,30 +25,37 @@ def captured_output():
2425
sys.stdout, sys.stderr = old_out, old_err
2526

2627

27-
class FPTestBase(object):
28+
class FPTestEdgeCases(object):
2829
"""
29-
Base testing class for frequent pattern mining. This class should include
30-
setup and tests common to all methods (e.g., error for improper input)
30+
Base class for testing edge cases for pattern mining.
3131
"""
3232

33-
def setUp(self, fpalgo, one_ary=None):
34-
if one_ary is None:
35-
self.one_ary = np.array(
36-
[[0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1],
37-
[0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1],
38-
[1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0],
39-
[0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1],
40-
[0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0]])
33+
def setUp(self, fpalgo):
34+
self.fpalgo = fpalgo
4135

42-
else:
43-
self.one_ary = one_ary
36+
def test_empty(self):
37+
df = pd.DataFrame([[]])
38+
res_df = self.fpalgo(df)
39+
expect = pd.DataFrame([], columns=['support', 'itemsets'])
40+
compare_dataframes(res_df, expect)
4441

42+
43+
class FPTestErrors(object):
44+
"""
45+
Base class for testing expected errors for pattern mining.
46+
"""
47+
48+
def setUp(self, fpalgo):
49+
self.one_ary = np.array(
50+
[[0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1],
51+
[0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1],
52+
[1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0],
53+
[0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1],
54+
[0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0]])
4555
self.cols = ['Apple', 'Corn', 'Dill', 'Eggs', 'Ice cream',
4656
'Kidney Beans', 'Milk',
4757
'Nutmeg', 'Onion', 'Unicorn', 'Yogurt']
48-
4958
self.df = pd.DataFrame(self.one_ary, columns=self.cols)
50-
5159
self.fpalgo = fpalgo
5260

5361
def test_itemsets_type(self):
@@ -84,6 +92,31 @@ def test_sparsedataframe_notzero_column(self):
8492
'`df.columns = [str(i) for i in df.columns`].',
8593
self.fpalgo, dfs)
8694

95+
96+
class FPTestEx1(object):
97+
"""
98+
Base class for testing frequent pattern mining on a small example.
99+
"""
100+
101+
def setUp(self, fpalgo, one_ary=None):
102+
if one_ary is None:
103+
self.one_ary = np.array(
104+
[[0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1],
105+
[0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1],
106+
[1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0],
107+
[0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1],
108+
[0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0]])
109+
else:
110+
self.one_ary = one_ary
111+
112+
self.cols = ['Apple', 'Corn', 'Dill', 'Eggs', 'Ice cream',
113+
'Kidney Beans', 'Milk',
114+
'Nutmeg', 'Onion', 'Unicorn', 'Yogurt']
115+
116+
self.df = pd.DataFrame(self.one_ary, columns=self.cols)
117+
118+
self.fpalgo = fpalgo
119+
87120
def test_frozenset_selection(self):
88121
res_df = self.fpalgo(self.df, use_colnames=True)
89122
assert res_df.values.shape == self.fpalgo(self.df).values.shape
@@ -117,9 +150,9 @@ def test_with_fill_values(fill_value):
117150
test_with_fill_values(False)
118151

119152

120-
class FPTestAll(FPTestBase):
153+
class FPTestEx1All(FPTestEx1):
121154
def setUp(self, fpalgo, one_ary=None):
122-
FPTestBase.setUp(self, fpalgo, one_ary=one_ary)
155+
FPTestEx1.setUp(self, fpalgo, one_ary=one_ary)
123156

124157
def test_default(self):
125158
res_df = self.fpalgo(self.df)
@@ -162,27 +195,62 @@ def test_low_memory_flag(self):
162195
assert True
163196

164197

165-
class FPTestMaximal(FPTestBase):
166-
def setUp(self, fpalgo, one_ary=None):
167-
FPTestBase.setUp(self, fpalgo, one_ary=one_ary)
198+
class FPTestEx2(object):
199+
"""
200+
Base class for testing frequent pattern mining on a small example.
201+
"""
168202

169-
def test_default(self):
170-
res_df = self.fpalgo(self.df)
171-
expect = pd.DataFrame([[0.6, frozenset([5, 6])],
172-
[0.6, frozenset([5, 10])],
173-
[0.6, frozenset([3, 5, 8])]],
203+
def setUp(self):
204+
database = [['a'], ['b'], ['c', 'd'], ['e']]
205+
te = TransactionEncoder()
206+
te_ary = te.fit(database).transform(database)
207+
208+
self.df = pd.DataFrame(te_ary, columns=te.columns_)
209+
210+
211+
class FPTestEx2All(FPTestEx2):
212+
def setUp(self, fpalgo):
213+
self.fpalgo = fpalgo
214+
FPTestEx2.setUp(self)
215+
216+
def test_output(self):
217+
res_df = self.fpalgo(self.df, min_support=0.001, use_colnames=True)
218+
expect = pd.DataFrame([[0.25, frozenset(['a'])],
219+
[0.25, frozenset(['b'])],
220+
[0.25, frozenset(['c'])],
221+
[0.25, frozenset(['d'])],
222+
[0.25, frozenset(['e'])],
223+
[0.25, frozenset(['c', 'd'])]],
174224
columns=['support', 'itemsets'])
175225

176226
compare_dataframes(res_df, expect)
177227

178-
def test_max_len(self):
179-
res_df1 = self.fpalgo(self.df)
180-
max_len = np.max(res_df1['itemsets'].apply(len))
181-
assert max_len == 3
182228

183-
res_df2 = self.fpalgo(self.df, max_len=2)
184-
max_len = np.max(res_df2['itemsets'].apply(len))
185-
assert max_len == 2
229+
class FPTestEx3(object):
230+
"""
231+
Base class for testing frequent pattern mining on a small example.
232+
"""
233+
234+
def setUp(self):
235+
database = [['a'], ['b'], ['c', 'd'], ['e']]
236+
te = TransactionEncoder()
237+
te_ary = te.fit(database).transform(database)
238+
239+
self.df = pd.DataFrame(te_ary, columns=te.columns_)
240+
241+
242+
class FPTestEx3All(FPTestEx3):
243+
def setUp(self, fpalgo):
244+
self.fpalgo = fpalgo
245+
FPTestEx3.setUp(self)
246+
247+
def test_output3(self):
248+
assert_raises(ValueError,
249+
'`min_support` must be a positive '
250+
'number within the interval `(0, 1]`. Got 0.0.',
251+
self.fpalgo,
252+
self.df,
253+
min_support=0.)
186254

187255

188256
def compare_dataframes(df1, df2):
Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
import unittest
22
import numpy as np
3-
from mlxtend.frequent_patterns.tests.test_fpbase import FPTestAll
3+
from test_fpbase import FPTestEdgeCases, FPTestErrors, \
4+
FPTestEx1All, FPTestEx2All, FPTestEx3All
45
from mlxtend.frequent_patterns import fpgrowth
56

67

7-
class TestFPGrowth(unittest.TestCase, FPTestAll):
8+
class TestEdgeCases(unittest.TestCase, FPTestEdgeCases):
89
def setUp(self):
9-
FPTestAll.setUp(self, fpgrowth)
10+
FPTestEdgeCases.setUp(self, fpgrowth)
1011

1112

12-
class TestFPGrowth2(unittest.TestCase, FPTestAll):
13+
class TestErrors(unittest.TestCase, FPTestErrors):
14+
def setUp(self):
15+
FPTestErrors.setUp(self, fpgrowth)
16+
17+
18+
class TestEx1(unittest.TestCase, FPTestEx1All):
19+
def setUp(self):
20+
FPTestEx1All.setUp(self, fpgrowth)
21+
22+
23+
class TestEx1BoolInput(unittest.TestCase, FPTestEx1All):
1324
def setUp(self):
1425
one_ary = np.array(
1526
[[False, False, False, True, False, True, True, True, True,
@@ -22,4 +33,14 @@ def setUp(self):
2233
True, True],
2334
[False, True, False, True, True, True, False, False, True,
2435
False, False]])
25-
FPTestAll.setUp(self, fpgrowth, one_ary=one_ary)
36+
FPTestEx1All.setUp(self, fpgrowth, one_ary=one_ary)
37+
38+
39+
class TestEx2(unittest.TestCase, FPTestEx2All):
40+
def setUp(self):
41+
FPTestEx2All.setUp(self, fpgrowth)
42+
43+
44+
class TestEx3(unittest.TestCase, FPTestEx3All):
45+
def setUp(self):
46+
FPTestEx3All.setUp(self, fpgrowth)

0 commit comments

Comments
 (0)