diff --git a/.github/workflows/test.sh b/.github/workflows/test.sh index 53c517b9..98fc73bb 100755 --- a/.github/workflows/test.sh +++ b/.github/workflows/test.sh @@ -30,7 +30,7 @@ else (( err += $? )) # CUDA, HIP, or SYCL. These fail gracefully when GPUs are absent. - ./run_tests.py ${args} --blas1-device --blas3-device + ./run_tests.py ${args} --blas1-device --blas2-device --blas3-device (( err += $? )) ./run_tests.py ${args} --batch-blas3-device diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e571559..3fe823e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -391,6 +391,7 @@ if (CUDAToolkit_FOUND) blaspp PRIVATE src/cuda/device_shift_vec.cu + src/cuda/device_conj.cu ) # Some platforms need these to be public libraries. target_link_libraries( @@ -400,6 +401,7 @@ elseif (rocblas_FOUND) blaspp PRIVATE src/hip/device_shift_vec.hip + src/hip/device_conj.hip ) # Some platforms need these to be public libraries. target_link_libraries( diff --git a/config/config.py b/config/config.py index 6df2aae6..ee670864 100644 --- a/config/config.py +++ b/config/config.py @@ -715,16 +715,18 @@ def sycl_onemkl_library(): Does not actually run the resulting exe, to allow compiling on a machine without GPUs. ''' - libs = '-lmkl_sycl -lsycl -lOpenCL' + ldflags = '-fsycl' + libs = '-lmkl_sycl -lsycl -lOpenCL' print_subhead( 'SYCL and oneMKL libraries' ) - print_test( ' ' + libs ) + print_test( ' ' + ldflags + ' ' + libs ) # Intel compiler vars.sh defines $CMPLR_ROOT root = environ['CMPLR_ROOT'] or environ['CMPROOT'] inc = '' if (root): inc = '-I' + root + '/linux/include ' # space at end for concat - env = {'LIBS': libs, + env = {'LDFLAGS': ldflags, + 'LIBS': libs, 'CXXFLAGS': inc + define('HAVE_SYCL') + ' -fsycl -Wno-deprecated-declarations'} (rc, out, err) = compile_exe( 'config/onemkl.cc', env ) diff --git a/include/blas/counter.hh b/include/blas/counter.hh index 4cdf18db..1c2afbd7 100644 --- a/include/blas/counter.hh +++ b/include/blas/counter.hh @@ -75,19 +75,42 @@ public: trmm, trsm, - // Device BLAS + // Level 1 BLAS + dev_asum, + dev_axpy, dev_copy, dev_dot, - dev_gemm, - dev_hemm, - dev_her2k, - dev_herk, + dev_dotu, + dev_iamax, dev_nrm2, + dev_rot, + dev_rotg, + dev_rotm, + dev_rotmg, dev_scal, dev_swap, + + // Level 2 BLAS + dev_gemv, + dev_ger, + dev_geru, + dev_hemv, + dev_her, + dev_her2, + dev_symv, + dev_syr, + dev_syr2, + dev_trmv, + dev_trsv, + + // Level 3 BLAS + dev_gemm, + dev_hemm, + dev_herk, + dev_her2k, dev_symm, - dev_syr2k, dev_syrk, + dev_syr2k, dev_trmm, dev_trsm, @@ -197,12 +220,38 @@ public: //============================================================================== // Device BLAS + typedef axpy_type dev_axpy_type; + typedef axpy_type dev_scal_type; typedef axpy_type dev_copy_type; + typedef axpy_type dev_swap_type; typedef axpy_type dev_dot_type; + typedef axpy_type dev_dotu_type; typedef axpy_type dev_nrm2_type; - typedef axpy_type dev_scal_type; - typedef axpy_type dev_swap_type; + typedef axpy_type dev_asum_type; + typedef axpy_type dev_iamax_type; + typedef axpy_type dev_rot_type; + typedef axpy_type dev_rotm_type; + typedef axpy_type dev_rotg_type; + typedef axpy_type dev_rotmg_type; + + //------------------------------------------------------------------------------ + typedef gemv_type dev_gemv_type; + + typedef hemv_type dev_hemv_type; + typedef hemv_type dev_symv_type; + typedef hemv_type dev_her_type; + typedef hemv_type dev_her2_type; + typedef hemv_type dev_syr_type; + typedef hemv_type dev_syr2_type; + + typedef trmv_type dev_trmv_type; + typedef trmv_type dev_trsv_type; + typedef ger_type dev_ger_type; + typedef ger_type dev_geru_type; + typedef ger_type dev_gerc_type; + + //------------------------------------------------------------------------------ typedef gemm_type dev_gemm_type; typedef hemm_type dev_hemm_type; @@ -405,7 +454,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::gemv( ptr->m, ptr->n ) * 1e9 * iter->count; printf( "gemv( %c, %lld, %lld ) count %d, flop count %.2e\n", - op2char( ptr->trans ), llong( ptr->m ), llong( ptr->n ), + to_char( ptr->trans ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -414,7 +463,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::hemv( ptr->n ) * 1e9 * iter->count; printf( "hemv( %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -422,7 +471,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::symv( ptr->n ) * 1e9 * iter->count; printf( "symv( %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -430,8 +479,8 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::trmv( ptr->n ) * 1e9 * iter->count; printf( "trmv( %c, %c, %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), - diag2char( ptr->diag), llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ), to_char( ptr->trans ), + to_char( ptr->diag), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -439,8 +488,8 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::trsv( ptr->n ) * 1e9 * iter->count; printf( "trsv( %c, %c, %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), - diag2char( ptr->diag), llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ), to_char( ptr->trans ), + to_char( ptr->diag), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -464,7 +513,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::her( ptr->n ) * 1e9 * iter->count; printf( "her( %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -472,7 +521,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::her2( ptr->n ) * 1e9 * iter->count; printf( "her2( %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -480,7 +529,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::syr( ptr->n ) * 1e9 * iter->count; printf( "syr( %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -488,7 +537,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::syr2( ptr->n ) * 1e9 * iter->count; printf( "syr2( %c, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } @@ -498,7 +547,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::gemm( ptr->m, ptr->n, ptr->k ) * 1e9 * iter->count; printf( "gemm( %c, %c, %lld, %lld, %lld ) count %d, flop count %.2e\n", - op2char( ptr->transA ), op2char( ptr->transB ), + to_char( ptr->transA ), to_char( ptr->transB ), llong( ptr->m ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; @@ -508,7 +557,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::hemm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "hemm( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), + to_char( ptr->side ), to_char( ptr->uplo ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -517,7 +566,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::her2k( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "her2k( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -526,7 +575,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::herk( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "herk( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -535,7 +584,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::symm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "symm( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), + to_char( ptr->side ), to_char( ptr->uplo ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -544,7 +593,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::syr2k( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "syr2k( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -553,7 +602,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::syrk( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "syrk( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -562,8 +611,8 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::trmm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "trmm( %c, %c, %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), - op2char( ptr->transA ), diag2char( ptr->diag ), + to_char( ptr->side ), to_char( ptr->uplo ), + to_char( ptr->transA ), to_char( ptr->diag ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -572,14 +621,30 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::trsm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "trsm( %c, %c, %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), - op2char( ptr->transA ), diag2char( ptr->diag ), + to_char( ptr->side ), to_char( ptr->uplo ), + to_char( ptr->transA ), to_char( ptr->diag ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; } - // Device BLAS + // Level 1 Device BLAS + case Id::dev_axpy: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::axpy( ptr->n ) * 1e9 * iter->count; + printf( "dev_axpy( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_scal: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::scal( ptr->n ) * 1e9 * iter->count; + printf( "dev_scal( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } case Id::dev_copy: { auto *ptr = static_cast( iter->ptr ); double flop = Gflop::copy( ptr->n ) * 1e9 * iter->count; @@ -588,6 +653,14 @@ public: totalflops += flop; break; } + case Id::dev_swap: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::swap( ptr->n ) * 1e9 * iter->count; + printf( "dev_swap( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } case Id::dev_dot: { auto *ptr = static_cast( iter->ptr ); double flop = Gflop::dot( ptr->n ) * 1e9 * iter->count; @@ -596,11 +669,168 @@ public: totalflops += flop; break; } + case Id::dev_dotu: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::dot( ptr->n ) * 1e9 * iter->count; + printf( "dev_dotu( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_nrm2: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::nrm2( ptr->n ) * 1e9 * iter->count; + printf( "dev_nrm2( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_asum: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::asum( ptr->n ) * 1e9 * iter->count; + printf( "dev_asum( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_iamax: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::iamax( ptr->n ) * 1e9 * iter->count; + printf( "dev_iamax( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_rotg: { + // auto *ptr = static_cast( iter->ptr ); + // double flop = Gflop::rotg( ptr->n ) * 1e9; + printf( "dev_rotg( ) count %d\n", iter->count ); + // totalflops += flop; + break; + } + case Id::dev_rot: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::rot( ptr->n ) * 1e9 * iter->count; + printf( "dev_rot( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_rotmg: { + // auto *ptr = static_cast( iter->ptr ); + // double flop = Gflop::rotmg( ptr->n ) * 1e9; + printf( "dev_rotmg( ) count %d\n", iter->count ); + // totalflops += flop; + break; + } + case Id::dev_rotm: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::rotm( ptr->n ) * 1e9 * iter->count; + printf( "dev_rotm( %lld ) count %d, flop count %.2e\n", + llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + + // Level 2 Device BLAS + case Id::dev_gemv: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::gemv( ptr->m, ptr->n ) * 1e9 * iter->count; + printf( "dev_gemv( %c, %lld, %lld ) count %d, flop count %.2e\n", + to_char( ptr->trans ), llong( ptr->m ), llong( ptr->n ), + iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_hemv: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::hemv( ptr->n ) * 1e9 * iter->count; + printf( "dev_hemv( %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_symv: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::symv( ptr->n ) * 1e9 * iter->count; + printf( "dev_symv( %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_trmv: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::trmv( ptr->n ) * 1e9 * iter->count; + printf( "dev_trmv( %c, %c, %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ), to_char( ptr->trans ), + to_char( ptr->diag), llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_trsv: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::trsv( ptr->n ) * 1e9 * iter->count; + printf( "dev_trsv( %c, %c, %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ), to_char( ptr->trans ), + to_char( ptr->diag), llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_ger: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::ger( ptr->m, ptr->n ) * 1e9 * iter->count; + printf( "dev_ger( %lld, %lld ) count %d, flop count %.2e\n", + llong( ptr->m ), llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_geru: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::ger( ptr->m, ptr->n ) * 1e9 * iter->count; + printf( "dev_geru( %lld, %lld ) count %d, flop count %.2e\n", + llong( ptr->m ), llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_her: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::her( ptr->n ) * 1e9 * iter->count; + printf( "dev_her( %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_her2: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::her2( ptr->n ) * 1e9 * iter->count; + printf( "dev_her2( %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_syr: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::syr( ptr->n ) * 1e9 * iter->count; + printf( "dev_syr( %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + case Id::dev_syr2: { + auto *ptr = static_cast( iter->ptr ); + double flop = Gflop::syr2( ptr->n ) * 1e9 * iter->count; + printf( "dev_syr2( %c, %lld ) count %d, flop count %.2e\n", + to_char( ptr->uplo ),llong( ptr->n ), iter->count, flop ); + totalflops += flop; + break; + } + + // Level 3 Device BLAS case Id::dev_gemm: { auto *ptr = static_cast( iter->ptr ); double flop = Gflop::gemm( ptr->m, ptr->n, ptr->k ) * 1e9 * iter->count; printf( "dev_gemm( %c, %c, %lld, %lld, %lld ) count %d, flop count %.2e\n", - op2char( ptr->transA ), op2char( ptr->transB ), + to_char( ptr->transA ), to_char( ptr->transB ), llong( ptr->m ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; @@ -610,7 +840,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::hemm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "dev_hemm( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), + to_char( ptr->side ), to_char( ptr->uplo ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -619,7 +849,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::her2k( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "dev_her2k( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -628,40 +858,16 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::herk( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "dev_herk( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; } - case Id::dev_nrm2: { - auto *ptr = static_cast( iter->ptr ); - double flop = Gflop::nrm2( ptr->n ) * 1e9 * iter->count; - printf( "dev_nrm2( %lld ) count %d, flop count %.2e\n", - llong( ptr->n ), iter->count, flop ); - totalflops += flop; - break; - } - case Id::dev_scal: { - auto *ptr = static_cast( iter->ptr ); - double flop = Gflop::scal( ptr->n ) * 1e9 * iter->count; - printf( "dev_scal( %lld ) count %d, flop count %.2e\n", - llong( ptr->n ), iter->count, flop ); - totalflops += flop; - break; - } - case Id::dev_swap: { - auto *ptr = static_cast( iter->ptr ); - double flop = Gflop::swap( ptr->n ) * 1e9 * iter->count; - printf( "dev_swap( %lld ) count %d, flop count %.2e\n", - llong( ptr->n ), iter->count, flop ); - totalflops += flop; - break; - } case Id::dev_symm: { auto *ptr = static_cast( iter->ptr ); double flop = Gflop::symm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "dev_symm( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), + to_char( ptr->side ), to_char( ptr->uplo ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -670,7 +876,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::syr2k( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "dev_syr2k( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -679,7 +885,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::syrk( ptr->n, ptr->k ) * 1e9 * iter->count; printf( "dev_syrk( %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - uplo2char( ptr->uplo ), op2char( ptr->trans ), + to_char( ptr->uplo ), to_char( ptr->trans ), llong( ptr->n ), llong( ptr->k ), iter->count, flop ); totalflops += flop; break; @@ -688,8 +894,8 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::trmm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "dev_trmm( %c, %c, %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), - op2char( ptr->transA ), diag2char( ptr->diag ), + to_char( ptr->side ), to_char( ptr->uplo ), + to_char( ptr->transA ), to_char( ptr->diag ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -698,8 +904,8 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::trsm( ptr->side, ptr->m, ptr->n ) * 1e9 * iter->count; printf( "dev_trsm( %c, %c, %c, %c, %lld, %lld ) count %d, flop count %.2e\n", - side2char( ptr->side ), uplo2char( ptr->uplo ), - op2char( ptr->transA ), diag2char( ptr->diag ), + to_char( ptr->side ), to_char( ptr->uplo ), + to_char( ptr->transA ), to_char( ptr->diag ), llong( ptr->m ), llong( ptr->n ), iter->count, flop ); totalflops += flop; break; @@ -710,7 +916,7 @@ public: auto *ptr = static_cast( iter->ptr ); double flop = Gflop::gemm( ptr->m, ptr->n, ptr->k ) * 1e9 * iter->count; printf( "dev_batch_gemm( %c, %c, %lld, %lld, %lld, %lld ) count %d, flop count %.2e\n", - op2char( ptr->transA ), op2char( ptr->transB ), + to_char( ptr->transA ), to_char( ptr->transB ), llong( ptr->m ), llong( ptr->n ), llong( ptr->k ), llong( ptr->batch_size ), iter->count, flop ); totalflops += flop; diff --git a/include/blas/device.hh b/include/blas/device.hh index 5118b1a0..3ce7d95b 100644 --- a/include/blas/device.hh +++ b/include/blas/device.hh @@ -27,6 +27,7 @@ #endif #include + #include // Headers moved in ROCm 5.2 #if HIP_VERSION >= 50200000 @@ -699,13 +700,43 @@ void Queue::work_ensure_size( size_t lwork ) } } -//------------------------------------------------------------------------------ -/// Add a constant c to an n-element vector v. -/// - template void shift_vec( int64_t n, scalar_t* v, scalar_t c, blas::Queue& queue ); +template +void conj( + int64_t n, + TS const* src, int64_t inc_src, + TD* dst, int64_t inc_dst, + blas::Queue& queue ); + +#if defined(BLAS_HAVE_SYCL) + +template +void conj( + int64_t n, + TS const* src, int64_t inc_src, + TD* dst, int64_t inc_dst, + blas::Queue& queue ) +{ + using std::conj; + + if (n <= 0) { + return; + } + + int64_t i_src = (inc_src > 0 ? 0 : (1 - n) * inc_src); + int64_t i_dst = (inc_dst > 0 ? 0 : (1 - n) * inc_dst); + + queue.stream().submit( [&]( sycl::handler& h ) { + h.parallel_for( sycl::range<1>(n), [=]( sycl::id<1> i ) { + dst[ i*inc_dst + i_dst ] = conj( src[ i*inc_src + i_src ] ); + } ); + } ); +} + +#endif // BLAS_HAVE_SYCL + } // namespace blas #endif // #ifndef BLAS_DEVICE_HH diff --git a/include/blas/util.hh b/include/blas/util.hh index 4adabb99..1fccf661 100644 --- a/include/blas/util.hh +++ b/include/blas/util.hh @@ -640,6 +640,16 @@ inline void abort_if( bool cond, const char* func, const char* format, ... ) #endif +//------------------------------------------------------------------------------ +/// Integer division rounding up instead of down +/// @return ceil( x / y ), for integer types T1, T2. +template +inline constexpr std::common_type_t ceildiv( T1 x, T2 y ) +{ + using T = std::common_type_t; + return T((x + y - 1) / y); +} + } // namespace blas #endif // #ifndef BLAS_UTIL_HH diff --git a/src/cuda/device_conj.cu b/src/cuda/device_conj.cu new file mode 100644 index 00000000..1711ec6b --- /dev/null +++ b/src/cuda/device_conj.cu @@ -0,0 +1,104 @@ +#include "blas/device.hh" +#include "thrust/complex.h" + +#if defined(BLAS_HAVE_CUBLAS) + +namespace blas { + +__device__ std::complex conj_convert( + std::complex z) +{ + ((cuComplex*) &z)->y *= -1; + return z; +} + +__device__ std::complex conj_convert( + std::complex z) +{ + ((cuDoubleComplex*) &z)->y *= -1; + return z; +} + +// Each thread conjugates 1 item +template +__global__ void conj_kernel( + int64_t n, + TS const* src, int64_t inc_src, int64_t i_src, + TD* dst, int64_t inc_dst, int64_t i_dst) +{ + using thrust::conj; + + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) + dst[ i*inc_dst + i_dst ] = conj_convert( src[ i*inc_src + i_src ] ); +} + +//------------------------------------------------------------------------------ +/// Conjugates each element of the vector src and stores in dst. +/// +/// @param[in] n +/// Number of elements in the vector. n >= 0. +/// +/// @param[in] src +/// Pointer to the input vector of length n. +/// +/// @param[in] inc_src +/// Stride between elements of src. inc_src >= 1. +/// +/// @param[out] dst +/// Pointer to output vector +/// On exit, each element dst[i] is updated as dst[i] = conj( src[i] ). +/// dst may be the same as src. +/// +/// @param[in] inc_dst +/// Stride between elements of dst. inc_dst >= 1. +/// +/// @param[in] queue +/// BLAS++ queue to execute in. +/// +template +void conj( + int64_t n, + TS const* src, int64_t inc_src, + TD* dst, int64_t inc_dst, + blas::Queue& queue ) +{ + if (n <= 0) { + return; + } + + const int64_t BlockSize = 128; + + int64_t n_threads = min( BlockSize, n ); + int64_t n_blocks = ceildiv(n, n_threads); + + int64_t i_src = (inc_src > 0 ? 0 : (1 - n) * inc_src); + int64_t i_dst = (inc_dst > 0 ? 0 : (1 - n) * inc_dst); + + blas_dev_call( + cudaSetDevice( queue.device() ) ); + + conj_kernel<<>>( + n, src, inc_src, i_src, dst, inc_dst, i_dst ); + + blas_dev_call( + cudaGetLastError() ); +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template void conj( + int64_t n, + std::complex const* src, int64_t inc_src, + std::complex* dst, int64_t inc_dst, + blas::Queue& queue); + +template void conj( + int64_t n, + std::complex const* src, int64_t inc_src, + std::complex* dst, int64_t inc_dst, + blas::Queue& queue); + +} // namespace blas + +#endif // BLAS_HAVE_CUBLAS diff --git a/src/device_asum.cc b/src/device_asum.cc index 3852b819..67650fc4 100644 --- a/src/device_asum.cc +++ b/src/device_asum.cc @@ -4,6 +4,7 @@ // the terms of the BSD 3-Clause license. See the accompanying LICENSE file. #include "blas/device_blas.hh" +#include "blas/counter.hh" #include "device_internal.hh" diff --git a/src/device_axpy.cc b/src/device_axpy.cc index a12ec30b..3fae9377 100644 --- a/src/device_axpy.cc +++ b/src/device_axpy.cc @@ -4,6 +4,7 @@ // the terms of the BSD 3-Clause license. See the accompanying LICENSE file. #include "blas/device_blas.hh" +#include "blas/counter.hh" #include "device_internal.hh" diff --git a/src/device_batch_gemm.cc b/src/device_batch_gemm.cc index a033bd07..6343ad5c 100644 --- a/src/device_batch_gemm.cc +++ b/src/device_batch_gemm.cc @@ -91,7 +91,7 @@ void gemm( element = { transA_, transB_, m_, n_, k_, batch_size }; counter::insert( element, counter::Id::dev_batch_gemm ); - double gflops = 1e9 * blas::Gflop< scalar_t >::gemm( m, n, k ); + double gflops = 1e9 * blas::Gflop< scalar_t >::gemm( m_, n_, k_ ); counter::inc_flop_count( (long long int)gflops ); #endif diff --git a/src/device_batch_hemm.cc b/src/device_batch_hemm.cc index f3edf295..451f0abd 100644 --- a/src/device_batch_hemm.cc +++ b/src/device_batch_hemm.cc @@ -65,7 +65,7 @@ void hemm( element = { batch_size }; counter::insert( element, counter::Id::dev_batch_hemm ); - double gflops = 1e9 * blas::Gflop< scalar_t >::hemm( side, m, n ); + double gflops = 1e9 * blas::Gflop< scalar_t >::hemm( side[0], m[0], n[0] ); counter::inc_flop_count( (long long int)gflops ); #endif diff --git a/src/device_iamax.cc b/src/device_iamax.cc index c0040426..d8bc02be 100644 --- a/src/device_iamax.cc +++ b/src/device_iamax.cc @@ -4,6 +4,7 @@ // the terms of the BSD 3-Clause license. See the accompanying LICENSE file. #include "blas/device_blas.hh" +#include "blas/counter.hh" #include "device_internal.hh" diff --git a/src/hip/device_conj.hip b/src/hip/device_conj.hip new file mode 100644 index 00000000..c970d153 --- /dev/null +++ b/src/hip/device_conj.hip @@ -0,0 +1,78 @@ +#include "blas/device.hh" +#include + +#if defined(BLAS_HAVE_ROCBLAS) + +namespace blas { + +__device__ std::complex conj_convert( + std::complex z) +{ + hipFloatComplex res = hipConjf(*(hipFloatComplex*) &z); + return *(std::complex*) &res; +} + +__device__ std::complex conj_convert( + std::complex z) +{ + hipDoubleComplex res = hipConj(*(hipDoubleComplex*) &z); + return *(std::complex*) &res; +} + +template +__global__ void conj_kernel( + int64_t n, + TS const* src, int64_t inc_src, int64_t i_src, + TD* dst, int64_t inc_dst, int64_t i_dst) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) + dst[ i*inc_dst + i_dst ] = conj_convert( src[ i*inc_src + i_src ] ); +} + +template +void conj( + int64_t n, + TS const* src, int64_t inc_src, + TD* dst, int64_t inc_dst, + blas::Queue& queue ) +{ + if (n <= 0) { + return; + } + + const int BlockSize = 128; + + int64_t n_threads = std::min( int64_t( BlockSize ), n ); + int64_t n_blocks = ceildiv(n, n_threads); + + int64_t i_src = (inc_src > 0 ? 0 : (1 - n) * inc_src); + int64_t i_dst = (inc_dst > 0 ? 0 : (1 - n) * inc_dst); + + blas_dev_call( + hipSetDevice( queue.device() ) ); + + conj_kernel<<>>( + n, src, inc_src, i_src, dst, inc_dst, i_dst ); + + blas_dev_call( + hipGetLastError() ); +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template void conj( + int64_t n, + std::complex const* src, int64_t inc_src, + std::complex* dst, int64_t inc_dst, + blas::Queue& queue); + +template void conj( + int64_t n, + std::complex const* src, int64_t inc_src, + std::complex* dst, int64_t inc_dst, + blas::Queue& queue); + +} // namespace blas + +#endif // BLAS_HAVE_ROCBLAS diff --git a/test/run_tests.py b/test/run_tests.py index d6da0f52..e83b003e 100755 --- a/test/run_tests.py +++ b/test/run_tests.py @@ -66,6 +66,7 @@ group_cat.add_argument( '--host', action='store_true', help='run all CPU host routines' ), group_cat.add_argument( '--blas1-device', action='store_true', help='run Level 1 BLAS on devices (GPUs)' ), + group_cat.add_argument( '--blas2-device', action='store_true', help='run Level 2 BLAS on devices (GPUs)' ), group_cat.add_argument( '--blas3-device', action='store_true', help='run Level 3 BLAS on devices (GPUs)' ), group_cat.add_argument( '--batch-blas3-device', action='store_true', help='run Level 3 Batch BLAS on devices (GPUs)' ), @@ -317,6 +318,9 @@ def filter_csv( values, csv ): [ 'trmv', dtype + layout + align + uplo + trans + diag + n + incx ], [ 'trsv', dtype + layout + align + uplo + trans + diag + n + incx ], ] + +if (opts.blas2_device): + cmds += [] # Level 3 if (opts.blas3):