Hi everyone!
This is meant as a way to gather feedback on the current status of numba_scipy.stats. I'm pinging people who have expressed interest in numba_scipy.stats and/or are involved in numba and scipy. I'd like to share what I've learned so far, and hopefully you'll share your perspective on this.
Since last year I've been looking into scipy.stats with a view to getting numba-scipy.stats started. I created a prototype in #16. It's viable, but the experience has led me to question the cost/benefit tradeoff of following that path.
The main technical complication with scipy.stats is that it is not function based but object based. It relies on several of Python's OO features, like inheritance and operator overloading. Numba has great support for function-based libraries (or method-based, when the number of objects is limited) like Numpy. However, the support for classes (via jitclasses) is more limited and experimental. Outside of jitclasses, the only other option is the extending module, with the added effort that implies.
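For context, this is the kind of object-based API that any jitted version would have to replicate (plain scipy here, nothing numba-specific, just an illustration):

```python
import numpy as np
from scipy import stats

x = np.linspace(-3.0, 3.0, 7)

# `norm` is an instance of a class (norm_gen, a subclass of rv_continuous),
# not a plain function; pdf/cdf/rvs come from that class hierarchy.
print(stats.norm.pdf(x, loc=0.0, scale=1.0))

# Calling the instance (operator overloading via __call__) builds a "frozen"
# distribution object that remembers its parameters.
frozen = stats.norm(loc=2.0, scale=0.5)
print(frozen.cdf(x))
print(frozen.rvs(size=3))
```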
The consequence of the above is that it will not be possible to fully imitate the behaviour of scipy.stats, at least not in the medium term and not without a lot of work.
Even if jitclasses worked exactly like Python classes, scipy.stats has more than a hundred distributions, each with more than 10 methods. If we followed the approach numba takes with numpy, we are talking about 1000+ methods to re-write. In some cases there will be performance improvements, but in others there won't.
Look at the following example:
```python
import numpy as np
import numba as nb
from numba import njit
from scipy.stats import norm


def foo():
    k = 20000
    x = np.zeros(k)
    for m in range(100):
        x += norm.rvs(m, 1, size=20000)
    return x


# For foo_jit to compile, norm.rvs has to be available in nopython mode,
# e.g. via the prototype in #16.
foo_jit = njit(foo)


@njit
def bar():
    k = 20000
    x = np.zeros(k)
    for m in range(100):
        with nb.objmode(y='float64[:]'):
            y = norm.rvs(m, 1, size=20000)
        x += y
    return x


%timeit foo()      # 66 ms ± 277 µs
foo_jit()          # compile
%timeit foo_jit()  # 73.7 ms ± 214 µs
bar()              # compile
%timeit bar()      # 65.8 ms ± 208 µs
```

There's no performance improvement at all, because most of the work is already done in C. This will be the case for many scipy.stats functions.
To summarize, I see a few ways forward, each with pros and cons:
1. jitclass-based solution (a rough sketch of what this could look like follows the list)
   - pros: easy for people to contribute (not much more than being competent with Python and having used numba before)
   - cons: won't replicate scipy's behaviour; we would regularly hit jitclass limitations and have to find workarounds; it would take thousands of hours of work. All that effort would not build anything new, just a copy of existing scipy features.
2. low-level numba extension (http://numba.pydata.org/numba-doc/latest/extending/low-level.html)
   - pros: should be able to reproduce all or most of scipy's behaviour
   - cons: harder to work with; it would increase the effort required and limit the number of contributors. Again, all that effort only copies existing scipy features.
3. objmode approach, i.e. no jitted solution
   - pros: all existing features are immediately supported.
   - cons: all currently slow methods remain slow; there is the added overhead of entering objmode, both in runtime and in boilerplate code. This last point might be made lighter by Calling objectmode function from nopython mode numba#5461 and Pass thru pyobjects numba#3282.
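To make option 1 concrete, here is a rough, illustrative sketch (not code from the #16 prototype; the class and method names are just placeholders) of what a single frozen distribution could look like as a jitclass:

```python
import numpy as np
from numba import float64
from numba.experimental import jitclass


# A hypothetical minimal "frozen normal"; nothing like the real rv_continuous machinery.
@jitclass([('loc', float64), ('scale', float64)])
class Normal:
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def pdf(self, x):
        # Evaluate the normal density elementwise on an array.
        z = (x - self.loc) / self.scale
        return np.exp(-0.5 * z * z) / (self.scale * np.sqrt(2.0 * np.pi))

    def rvs(self, size):
        # Draw samples one by one to stay within numba-supported random calls.
        out = np.empty(size)
        for i in range(size):
            out[i] = np.random.normal(self.loc, self.scale)
        return out


d = Normal(0.0, 1.0)
print(d.pdf(np.array([-1.0, 0.0, 1.0])))
print(d.rvs(5))
```

Even this toy version ignores broadcasting rules, argument checking, fit, moments, and so on; multiply it by 100+ distributions with 10+ methods each and you get to the 1000+ methods mentioned above.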
I personally lean towards option 3 at the moment. I might write some custom code that calls special functions if I really need performance. But I'm not feeling very attracted to the idea of re-implementing such a large module as scipy.stats.
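By "custom code that calls special functions" I mean something as simple as the sketch below: hand-written densities (or calls into scipy.special via numba-scipy's existing special-function support) inside njit functions, used only in the hot loops where objmode round-trips would hurt. The helper names here are made up for illustration:

```python
import math

import numpy as np
from numba import njit


@njit
def norm_logpdf(x, loc, scale):
    # Hand-written normal log-density; no scipy needed inside the jitted code.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


@njit
def norm_cdf(x, loc, scale):
    # Normal CDF via math.erf, which numba supports in nopython mode.
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))


@njit
def total_loglik(data, loc, scale):
    # A hot loop where avoiding per-iteration objmode calls actually pays off.
    acc = 0.0
    for i in range(data.shape[0]):
        acc += norm_logpdf(data[i], loc, scale)
    return acc


data = np.random.normal(0.0, 1.0, 10_000)
print(total_loglik(data, 0.0, 1.0))
print(norm_cdf(1.96, 0.0, 1.0))
```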
It would be great to hear your perspective on this.
cc: @gioxc88 @francoislauger @stnatter @remidebette @rileymcdowell @person142 @stuartarchibald