@@ -237,16 +237,27 @@ def _numpy(self, data, weights, shape):
         self._checkNPQuantity(q, shape)
         self._checkNPWeights(weights, shape)
         weights = self._makeNPWeights(weights, shape)
+        newentries = weights.sum()
+
+        subweights = weights.copy()
+        subweights[weights < 0.0] = 0.0
 
-        # no possibility of exception from here on out (for rollback)
-        for x, w in zip(q, weights):
-            if w > 0.0:
-                if x not in self.bins:
-                    self.bins[x] = self.value.zero()
-                self.bins[x].fill(x, w)
+        import numpy
+        selection = numpy.empty(q.shape, dtype=numpy.bool)
+
+        uniques, inverse = numpy.unique(q, return_inverse=True)
 
         # no possibility of exception from here on out (for rollback)
-        self.entries += float(weights.sum())
+        for i, x in enumerate(uniques):
+            if x not in self.bins:
+                self.bins[x] = self.value.zero()
+
+            numpy.not_equal(inverse, i, selection)
+            subweights[:] = weights
+            subweights[selection] = 0.0
+            self.bins[x]._numpy(data, subweights, shape)
+
+        self.entries += float(newentries)
 
     def _sparksql(self, jvm, converter):
         return converter.Categorize(self.quantity.asSparkSQL(), self.value._sparksql(jvm, converter))
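The change replaces the per-entry Python loop with a grouped fill: numpy.unique collapses the quantity array to its distinct category values, and each sub-aggregator is filled once with a weight array in which entries belonging to other categories are masked to zero. A minimal standalone sketch of that masking technique follows; the values, weights, and dict-based totals are hypothetical stand-ins for the histogrammar quantity and bins, and the sketch uses the builtin bool dtype because numpy.bool has since been removed from NumPy.

import numpy as np

# Hypothetical inputs: categorical values and per-entry weights.
values = np.array(["a", "b", "a", "c", "b", "a"])
weights = np.array([1.0, 0.5, 2.0, 1.0, -1.0, 0.25])

# Clamp negative weights to zero, mirroring the subweights clamp in the diff.
clamped = weights.copy()
clamped[weights < 0.0] = 0.0

# One numpy.unique call replaces the Python-level loop over every entry:
# inverse[j] is the index into uniques for entry j.
uniques, inverse = np.unique(values, return_inverse=True)

totals = {}
selection = np.empty(values.shape, dtype=bool)
for i, x in enumerate(uniques):
    # Mask out every entry that does not belong to category x,
    # then aggregate the remaining weights for that category.
    np.not_equal(inverse, i, out=selection)
    subweights = clamped.copy()
    subweights[selection] = 0.0
    totals[x] = subweights.sum()

print(totals)  # {'a': 3.25, 'b': 0.5, 'c': 1.0}

The cost of the fill thus scales with the number of distinct categories rather than the number of entries, at the price of one extra weight-sized buffer per pass.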