Skip to content

Commit a6ee40e

Browse files
authored
feat: add math.nextafter (#543)
1 parent b43dcc8 commit a6ee40e

File tree

5 files changed

+27
-1
lines changed

5 files changed

+27
-1
lines changed

docs/source/user/cudapysupported.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ The following functions from the :mod:`math` module are supported:
225225
* :func:`math.log2`
226226
* :func:`math.log10`
227227
* :func:`math.log1p`
228+
* :func:`math.nextafter` (Excluding the ``steps`` keyword argument)
228229
* :func:`math.sqrt`
229230
* :func:`math.remainder`
230231
* :func:`math.pow`

numba_cuda/numba/cuda/cudamath.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class Math_hypot(ConcreteTemplate):
8787

8888
@infer_global(math.copysign)
8989
@infer_global(math.fmod)
90+
@infer_global(math.nextafter)
9091
class Math_binary(ConcreteTemplate):
9192
cases = [
9293
signature(types.float32, types.float32, types.float32),

numba_cuda/numba/cuda/mathimpl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
binarys += [("fmod", "fmodf", math.fmod)]
6868
binarys += [("hypot", "hypotf", math.hypot)]
6969
binarys += [("remainder", "remainderf", math.remainder)]
70+
binarys += [("nextafter", "nextafterf", math.nextafter)]
7071

7172
binarys_fastmath = {}
7273
binarys_fastmath["powf"] = "fast_powf"

numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numba import cuda
88
from numba.cuda import float32
99
from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx
10-
from math import cos, sin, tan, exp, log, log10, log2, pow, tanh
10+
from math import cos, sin, tan, exp, log, log10, log2, pow, tanh, nextafter
1111
from operator import truediv
1212
import numpy as np
1313
from numba.cuda.testing import CUDATestCase, skip_on_cudasim, skip_unless_cc_75
@@ -194,6 +194,15 @@ def test_powf(self):
194194
),
195195
)
196196

197+
def test_nextafterf(self):
198+
self._test_fast_math_binary(
199+
nextafter,
200+
FastMathCriterion(
201+
fast_expected=[".ftz.f32 "],
202+
prec_unexpected=[".ftz.f32 "],
203+
),
204+
)
205+
197206
def test_divf(self):
198207
self._test_fast_math_binary(
199208
truediv,

numba_cuda/numba/cuda/tests/cudapy/test_math.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def math_remainder(A, B, C):
146146
C[i] = math.remainder(A[i], B[i])
147147

148148

149+
def math_nextafter(A, B, C):
150+
i = cuda.grid(1)
151+
C[i] = math.nextafter(A[i], B[i])
152+
153+
149154
def math_sqrt(A, B):
150155
i = cuda.grid(1)
151156
B[i] = math.sqrt(A[i])
@@ -614,6 +619,15 @@ def test_0_0(r, x, y):
614619
test_0_0[1, 1](r, 0, 0)
615620
self.assertTrue(np.isnan(r[0]))
616621

622+
# ---------------------------------------------------------------------------
623+
# test_math_nextafter
624+
625+
def test_math_nextafter(self):
626+
self.binary_template_float32(math_nextafter, np.nextafter, start=1e-11)
627+
self.binary_template_float64(math_remainder, np.remainder, start=1e-11)
628+
self.binary_template_int64(math_remainder, np.remainder, start=1)
629+
self.binary_template_uint64(math_remainder, np.remainder, start=1)
630+
617631
# ---------------------------------------------------------------------------
618632
# test_math_sqrt
619633

0 commit comments

Comments
 (0)