
For discussion: numba_scipy.stats #42

@luk-f-a

Description

hi everyone!

this is meant as a way to gather feedback on the current status of numba_scipy.stats. I'm pinging people that have expressed interest in numba_scipy.stats and/or are involved in numba and scipy. I'd like to share what I've learned so far, and hopefully you'll share your perspective on this.

Since last year I've been looking into scipy.stats with a view to getting numba_scipy.stats started. I created a prototype in #16. It's viable, but the experience has led me to question the cost/benefit tradeoff of following that path.

The main technical complication with scipy.stats is that it is not function based but object based. It relies on several of Python's OO features, such as inheritance and operator overloading. Numba has great support for function-based libraries (or method-based ones, when the number of object types is limited) like NumPy. However, its support for classes (via jitclasses) is more limited and experimental. Outside of jitclasses, the only other option is the extending module, with the added effort that implies.
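To make the object-based design concrete (a plain-Python illustration, nothing numba-specific): `norm` is not a function but an instance of a class in the `rv_continuous` hierarchy, and calling it with parameters produces yet another ("frozen") object whose methods dispatch through that hierarchy.

```python
from scipy.stats import norm, rv_continuous

# norm is an instance of norm_gen, a subclass of rv_continuous,
# not a plain function.
assert isinstance(norm, rv_continuous)
print(type(norm).__name__)  # norm_gen

# Calling it "freezes" the parameters into a new rv_frozen object.
frozen = norm(loc=2.0, scale=3.0)
print(frozen.mean(), frozen.std())  # 2.0 3.0
```

Reproducing this pattern in nopython mode means supporting instance creation, inheritance, and method dispatch, which is exactly where jitclasses are weakest.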

The consequence of the above is that fully imitating the behaviour of scipy.stats will not be possible, at least not in the medium term, and not without a lot of work.

Even if jitclasses worked exactly like Python classes, scipy.stats has more than a hundred distributions, each with more than 10 methods. If we followed the approach numba uses to support numpy, we are talking about 1000+ methods to re-write. In some cases there will be performance improvements, but in others there won't.

Look at the following example:

import numpy as np
import numba as nb
from numba import njit
from scipy.stats import norm

def foo():
    k = 20000
    x = np.zeros(k)
    for m in range(100):
        x += norm.rvs(m, 1, size=k)
    return x

# note: norm.rvs inside njit-compiled code relies on the
# numba_scipy.stats prototype from #16
foo_jit = njit(foo)

@njit
def bar():
    k = 20000
    x = np.zeros(k)
    for m in range(100):
        # fall back to the interpreter for the scipy call
        with nb.objmode(y='float64[:]'):
            y = norm.rvs(m, 1, size=k)
        x += y
    return x

%timeit foo() #66 ms ± 277 µs

foo_jit()
%timeit foo_jit() #73.7 ms ± 214 µs

bar()
%timeit bar() #65.8 ms ± 208 µs

There's no performance improvement at all, because most of the work is already done in C. This will be the case for many scipy.stats functions.

To summarize, I see a few ways forward, each with pros and cons:

  • jitclass based solution

    • pros: easy for people to contribute (not much more than being competent with python and having used numba before)
    • cons: won't replicate scipy's behaviour, will regularly run into jitclass's limitations and need workarounds, and will require thousands of hours to build. All that effort does not build anything new, just a copy of existing scipy features.
  • low-level numba extension (http://numba.pydata.org/numba-doc/latest/extending/low-level.html)

    • pros: should be able to reproduce all or most behaviour
    • cons: harder to work with; would increase the effort required and limit the number of contributors. All that effort, again, does not build anything new, just a copy of existing scipy features.
  • objmode approach, i.e. no jitted solution (as in bar above)

I personally lean towards option 3 at the moment. If I really need performance, I might write custom code that calls special functions directly. But I'm not attracted to the idea of re-implementing a module as large as scipy.stats.

It would be great to hear your perspective on this.

cc: @gioxc88 @francoislauger @stnatter @remidebette @rileymcdowell @person142 @stuartarchibald
