Skip to content

Commit 363c82d

Browse files
committed
feat: add math nextafter
1 parent 2567b28 commit 363c82d

File tree

5 files changed

+28
-1
lines changed

5 files changed

+28
-1
lines changed

docs/source/user/cudapysupported.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ The following functions from the :mod:`math` module are supported:
224224
* :func:`math.log2`
225225
* :func:`math.log10`
226226
* :func:`math.log1p`
227+
* :func:`math.nextafter`
227228
* :func:`math.sqrt`
228229
* :func:`math.remainder`
229230
* :func:`math.pow`

numba_cuda/numba/cuda/cudamath.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class Math_hypot(ConcreteTemplate):
8282

8383
@infer_global(math.copysign)
8484
@infer_global(math.fmod)
85+
@infer_global(math.nextafter)
8586
class Math_binary(ConcreteTemplate):
8687
cases = [
8788
signature(types.float32, types.float32, types.float32),

numba_cuda/numba/cuda/mathimpl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@
6464
binarys += [("fmod", "fmodf", math.fmod)]
6565
binarys += [("hypot", "hypotf", math.hypot)]
6666
binarys += [("remainder", "remainderf", math.remainder)]
67+
binarys += [("nextafter", "nextafterf", math.nextafter)]
6768

6869
binarys_fastmath = {}
6970
binarys_fastmath["powf"] = "fast_powf"
71+
binarys_fastmath["nextafterf"] = "fast_nextafterf"
7072

7173

7274
@lower(math.isinf, types.Integer)

numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from numba import cuda, float32
77
from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx
8-
from math import cos, sin, tan, exp, log, log10, log2, pow, tanh
8+
from math import cos, sin, tan, exp, log, log10, log2, pow, tanh, nextafter
99
from operator import truediv
1010
import numpy as np
1111
from numba.cuda.testing import CUDATestCase, skip_on_cudasim, skip_unless_cc_75
@@ -179,6 +179,15 @@ def test_powf(self):
179179
),
180180
)
181181

182+
def test_nextafterf(self):
183+
self._test_fast_math_binary(
184+
nextafter,
185+
FastMathCriterion(
186+
fast_expected=["lg2.approx.ftz.f32 "], # FIX
187+
prec_unexpected=["lg2.approx.ftz.f32 "], # FIX
188+
),
189+
)
190+
182191
def test_divf(self):
183192
self._test_fast_math_binary(
184193
truediv,

numba_cuda/numba/cuda/tests/cudapy/test_math.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ def math_remainder(A, B, C):
138138
C[i] = math.remainder(A[i], B[i])
139139

140140

141+
def math_nextafter(A, B, C):
142+
i = cuda.grid(1)
143+
C[i] = math.nextafter(A[i], B[i])
144+
145+
141146
def math_sqrt(A, B):
142147
i = cuda.grid(1)
143148
B[i] = math.sqrt(A[i])
@@ -594,6 +599,15 @@ def test_0_0(r, x, y):
594599
test_0_0[1, 1](r, 0, 0)
595600
self.assertTrue(np.isnan(r[0]))
596601

602+
# ---------------------------------------------------------------------------
603+
# test_math_nextafter
604+
605+
def test_math_nextafter(self):
606+
self.binary_template_float32(math_nextafter, np.nextafter, start=1e-11)
607+
self.binary_template_float64(math_remainder, np.remainder, start=1e-11)
608+
self.binary_template_int64(math_remainder, np.remainder, start=1)
609+
self.binary_template_uint64(math_remainder, np.remainder, start=1)
610+
597611
# ---------------------------------------------------------------------------
598612
# test_math_sqrt
599613

0 commit comments

Comments
 (0)