Commit 1c6d548

Proper Numpy solution for Categorize (bug-fix and also 10X speedup).
1 parent d14a4fa

histogrammar/primitives/categorize.py

Lines changed: 18 additions & 7 deletions
@@ -237,16 +237,27 @@ def _numpy(self, data, weights, shape):
         self._checkNPQuantity(q, shape)
         self._checkNPWeights(weights, shape)
         weights = self._makeNPWeights(weights, shape)
+        newentries = weights.sum()
+
+        subweights = weights.copy()
+        subweights[weights < 0.0] = 0.0
 
-        # no possibility of exception from here on out (for rollback)
-        for x, w in zip(q, weights):
-            if w > 0.0:
-                if x not in self.bins:
-                    self.bins[x] = self.value.zero()
-                self.bins[x].fill(x, w)
+        import numpy
+        selection = numpy.empty(q.shape, dtype=numpy.bool)
+
+        uniques, inverse = numpy.unique(q, return_inverse=True)
 
         # no possibility of exception from here on out (for rollback)
-        self.entries += float(weights.sum())
+        for i, x in enumerate(uniques):
+            if x not in self.bins:
+                self.bins[x] = self.value.zero()
+
+            numpy.not_equal(inverse, i, selection)
+            subweights[:] = weights
+            subweights[selection] = 0.0
+            self.bins[x]._numpy(data, subweights, shape)
+
+        self.entries += float(newentries)
 
     def _sparksql(self, jvm, converter):
         return converter.Categorize(self.quantity.asSparkSQL(), self.value._sparksql(jvm, converter))
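The change replaces the old row-by-row Python loop with one vectorized pass per unique category: numpy.unique groups the values, and for each category the weights of all other rows are masked to zero before the sub-weights are handed to that bin's aggregator. Below is a rough standalone sketch of that grouping trick only; the fill_categorized function and the plain dict of summed weights are illustrative stand-ins, not histogrammar's actual Categorize API.

import numpy as np

def fill_categorized(bins, values, weights):
    # Illustrative only: `bins` is a plain dict of category -> summed weight,
    # standing in for the per-category sub-aggregators in the real primitive.
    # Negative weights are clamped to zero, mirroring subweights[weights < 0.0] = 0.0.
    weights = np.where(weights < 0.0, 0.0, weights)

    # Group identical category values; inverse[j] says which unique value row j has.
    uniques, inverse = np.unique(values, return_inverse=True)

    subweights = np.empty_like(weights)
    for i, x in enumerate(uniques):
        # Keep the weights of rows belonging to category x, zero out everything else,
        # then hand the masked weights to that category's accumulator.
        subweights[:] = weights
        subweights[inverse != i] = 0.0
        bins[str(x)] = bins.get(str(x), 0.0) + float(subweights.sum())
    return bins

print(fill_categorized({}, np.array(["a", "b", "a", "a"]),
                       np.array([1.0, 2.0, 0.5, -1.0])))
# {'a': 1.5, 'b': 2.0} -- the negative weight contributes nothing

The inner loop now runs once per distinct category rather than once per data row, which is presumably where the quoted 10X speedup comes from when there are few categories relative to the number of rows. The visible behavioural difference is that each sub-aggregator now receives the full data with masked weights via _numpy(data, subweights, shape) instead of being filled with the category label itself, which appears to be the bug the commit message refers to.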
