|
1 | 1 | import re |
2 | | -import types |
| 2 | +import cffi |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 |
|
6 | | -from numba.cuda.testing import unittest, skip_on_cudasim, CUDATestCase |
7 | | -from numba import cuda, jit, float32, int32 |
| 6 | +from numba.cuda.testing import (skip_on_cudasim, test_data_dir, unittest, |
| 7 | + CUDATestCase) |
| 8 | +from numba import cuda, jit, float32, int32, types |
8 | 9 | from numba.core.errors import TypingError |
| 10 | +from types import ModuleType |
9 | 11 |
|
10 | 12 |
|
11 | 13 | class TestDeviceFunc(CUDATestCase): |
@@ -92,7 +94,7 @@ def test_cpu_dispatcher_other_module(self): |
92 | 94 | def add(a, b): |
93 | 95 | return a + b |
94 | 96 |
|
95 | | - mymod = types.ModuleType(name='mymod') |
| 97 | + mymod = ModuleType(name='mymod') |
96 | 98 | mymod.add = add |
97 | 99 | del add |
98 | 100 |
|
@@ -192,31 +194,128 @@ def rgba_caller(x, channels): |
192 | 194 |
|
193 | 195 | self.assertEqual(0x04010203, x[0]) |
194 | 196 |
|
195 | | - def _test_declare_device(self, decl): |
| 197 | + |
| 198 | +times2_cu = cuda.CUSource(""" |
| 199 | +extern "C" __device__ |
| 200 | +int times2(int *out, int a) |
| 201 | +{ |
| 202 | + *out = a * 2; |
| 203 | + return 0; |
| 204 | +} |
| 205 | +""") |
| 206 | + |
| 207 | + |
| 208 | +times4_cu = cuda.CUSource(""" |
| 209 | +extern "C" __device__ |
| 210 | +int times2(int *out, int a); |
| 211 | +
|
| 212 | +extern "C" __device__ |
| 213 | +int times4(int *out, int a) |
| 214 | +{ |
| 215 | + int tmp; |
| 216 | + times2(&tmp, a); |
| 217 | + *out = tmp * 2; |
| 218 | + return 0; |
| 219 | +} |
| 220 | +""") |
| 221 | + |
| 222 | +jitlink_user_cu = cuda.CUSource(""" |
| 223 | +extern "C" __device__ |
| 224 | +int array_mutator(void *out, int *a); |
| 225 | +
|
| 226 | +extern "C" __device__ |
| 227 | +int use_array_mutator(void *out, int *a) { |
| 228 | + array_mutator(out, a); |
| 229 | + return 0; |
| 230 | +} |
| 231 | +""") |
| 232 | + |
| 233 | + |
| 234 | +@skip_on_cudasim('External functions unsupported in the simulator') |
| 235 | +class TestDeclareDevice(CUDATestCase): |
| 236 | + |
| 237 | + def check_api(self, decl): |
196 | 238 | self.assertEqual(decl.name, 'f1') |
197 | 239 | self.assertEqual(decl.sig.args, (float32[:],)) |
198 | 240 | self.assertEqual(decl.sig.return_type, int32) |
199 | 241 |
|
200 | | - @skip_on_cudasim('cudasim does not check signatures') |
201 | 242 | def test_declare_device_signature(self): |
202 | 243 | f1 = cuda.declare_device('f1', int32(float32[:])) |
203 | | - self._test_declare_device(f1) |
| 244 | + self.check_api(f1) |
204 | 245 |
|
205 | | - @skip_on_cudasim('cudasim does not check signatures') |
206 | 246 | def test_declare_device_string(self): |
207 | 247 | f1 = cuda.declare_device('f1', 'int32(float32[:])') |
208 | | - self._test_declare_device(f1) |
| 248 | + self.check_api(f1) |
209 | 249 |
|
210 | | - @skip_on_cudasim('cudasim does not check signatures') |
211 | 250 | def test_bad_declare_device_tuple(self): |
212 | 251 | with self.assertRaisesRegex(TypeError, 'Return type'): |
213 | 252 | cuda.declare_device('f1', (float32[:],)) |
214 | 253 |
|
215 | | - @skip_on_cudasim('cudasim does not check signatures') |
216 | 254 | def test_bad_declare_device_string(self): |
217 | 255 | with self.assertRaisesRegex(TypeError, 'Return type'): |
218 | 256 | cuda.declare_device('f1', '(float32[:],)') |
219 | 257 |
|
| 258 | + def test_link_cu_source(self): |
| 259 | + times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu) |
| 260 | + |
| 261 | + @cuda.jit |
| 262 | + def kernel(r, x): |
| 263 | + i = cuda.grid(1) |
| 264 | + if i < len(r): |
| 265 | + r[i] = times2(x[i]) |
| 266 | + |
| 267 | + x = np.arange(10, dtype=np.int32) |
| 268 | + r = np.empty_like(x) |
| 269 | + |
| 270 | + kernel[1, 32](r, x) |
| 271 | + |
| 272 | + np.testing.assert_equal(r, x * 2) |
| 273 | + |
| 274 | + def _test_link_multiple_sources(self, link_type): |
| 275 | + link = link_type([times2_cu, times4_cu]) |
| 276 | + times4 = cuda.declare_device('times4', 'int32(int32)', link=link) |
| 277 | + |
| 278 | + @cuda.jit |
| 279 | + def kernel(r, x): |
| 280 | + i = cuda.grid(1) |
| 281 | + if i < len(r): |
| 282 | + r[i] = times4(x[i]) |
| 283 | + |
| 284 | + x = np.arange(10, dtype=np.int32) |
| 285 | + r = np.empty_like(x) |
| 286 | + |
| 287 | + kernel[1, 32](r, x) |
| 288 | + |
| 289 | + np.testing.assert_equal(r, x * 4) |
| 290 | + |
| 291 | + def test_link_multiple_sources_set(self): |
| 292 | + self._test_link_multiple_sources(set) |
| 293 | + |
| 294 | + def test_link_multiple_sources_tuple(self): |
| 295 | + self._test_link_multiple_sources(tuple) |
| 296 | + |
| 297 | + def test_link_multiple_sources_list(self): |
| 298 | + self._test_link_multiple_sources(list) |
| 299 | + |
| 300 | + def test_link_sources_in_memory_and_on_disk(self): |
| 301 | + jitlink_cu = str(test_data_dir / "jitlink.cu") |
| 302 | + link = [jitlink_cu, jitlink_user_cu] |
| 303 | + sig = types.void(types.CPointer(types.int32)) |
| 304 | + ext_fn = cuda.declare_device("use_array_mutator", sig, link=link) |
| 305 | + |
| 306 | + ffi = cffi.FFI() |
| 307 | + |
| 308 | + @cuda.jit |
| 309 | + def kernel(x): |
| 310 | + ptr = ffi.from_buffer(x) |
| 311 | + ext_fn(ptr) |
| 312 | + |
| 313 | + x = np.arange(2, dtype=np.int32) |
| 314 | + kernel[1, 1](x) |
| 315 | + |
| 316 | + expected = np.ones(2, dtype=np.int32) |
| 317 | + np.testing.assert_equal(x, expected) |
| 318 | + |
220 | 319 |
|
221 | 320 | if __name__ == '__main__': |
222 | 321 | unittest.main() |
0 commit comments