diff --git a/GNUmakefile b/GNUmakefile index 8a12fdd7..44f152b4 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -37,6 +37,14 @@ make.inc: RANLIB ?= ranlib prefix ?= /opt/slate +NVCC ?= nvcc +HIPCC ?= hipcc +hipify ?= hipify-perl +md5sum ?= tools/md5sum.pl + +NVCCFLAGS += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function' +HIPCCFLAGS += -std=c++11 -DTCE_HIP -fno-gpu-rdc + abs_prefix := ${abspath ${prefix}} # Default LD=ld won't work; use CXX. Can override in make.inc or environment. @@ -52,11 +60,25 @@ ifneq ($(findstring darwin, $(ostype)),) macos = 1 endif +#------------------------------------------------------------------------------- +# Detect which gpu_backend used +cuda = 0 +hip = 0 +sycl = 0 + +ifeq ($(gpu_backend),cuda) + cuda = 1 +else ifeq ($(gpu_backend),hip) + hip = 1 +endif + #------------------------------------------------------------------------------- # if shared ifneq ($(static),1) CXXFLAGS += -fPIC LDFLAGS += -fPIC + NVCCFLAGS += --compiler-options '-fPIC' + HIPCCFLAGS += -fPIC lib_ext = so else lib_ext = a @@ -77,7 +99,19 @@ lib_src = $(wildcard src/*.cc) lib_obj = $(addsuffix .o, $(basename $(lib_src))) dep += $(addsuffix .d, $(basename $(lib_src))) +cuda_src = $(wildcard test/cuda/*.cu) +hip_src = $(patsubst test/cuda/%.cu,test/hip/%.hip.cc,$(cuda_src)) + tester_src = $(wildcard test/*.cc) + +ifeq ($(cuda),1) + tester_src += $(cuda_src) +endif + +ifeq ($(hip),1) + tester_src += $(hip_src) +endif + tester_obj = $(addsuffix .o, $(basename $(tester_src))) dep += $(addsuffix .d, $(basename $(tester_src))) @@ -123,6 +157,8 @@ src/version.o: .id #------------------------------------------------------------------------------- # BLAS++ specific flags and libraries CXXFLAGS += -I./include +NVCCFLAGS += -I./include +HIPCCFLAGS += -I./include # additional flags and libraries for testers $(tester_obj): CXXFLAGS += -I$(testsweeper_dir) @@ -158,6 +194,59 @@ uninstall: $(RM) $(DESTDIR)$(abs_prefix)/lib$(LIB_SUFFIX)/libblaspp.* $(RM) $(DESTDIR)$(abs_prefix)/lib$(LIB_SUFFIX)/pkgconfig/blaspp.pc +#------------------------------------------------------------------------------- +# HIP sources converted from CUDA sources. + +# if_md5_outdated applies the given build rule ($1) only if the md5 sums +# of the target's dependency ($<) doesn't match that stored in the +# target's dep file ($@.dep). If the target ($@) is already up-to-date +# based on md5 sums, its timestamp is updated so make will recognize it +# as up-to-date. Otherwise, the target is built and its dep file +# updated. Instead of depending on the src file, the target depends on +# the md5 file of the src file. This can be adapted for multiple dependencies. +# Example usage: +# +# %: %.c.md5 +# ${call if_md5_outdated,\ +# gcc -o $@ ${basename $<}} +# +define if_md5_outdated + if [ -e $@ ] && diff $< $@.dep > /dev/null 2>&1; then \ + echo " make: '$@' is up-to-date based on md5sum."; \ + echo " touch $@"; \ + touch $@; \ + else \ + echo " make: '$@' is out-of-date based on md5sum."; \ + echo " ${strip $1}"; \ + $1; \ + cp $< $@.dep; \ + fi +endef + +# From GNU manual: Commas ... cannot appear in an argument as written. +# The[y] can be put into the argument value by variable substitution. +comma := , + +# Convert CUDA => HIP code. +# Explicitly mention ${hip_src}, ${hip_hdr}, ${md5_files} +# to prevent them from being intermediate files, +# so they are _always_ generated and never removed. +# Perl updates includes and removes excess spaces that fail style hook. +${hip_src}: test/hip/%.hip.cc: test/cuda/%.cu.md5 | test/hip + @${call if_md5_outdated, \ + ${hipify} ${basename $<} > $@; \ + perl -pi -e 's/\.cuh/.hip.hh/g; s/ +(${comma}|;|$$)/$$1/g;' $@} + +hipify: ${hip_src} + +md5_files := ${addsuffix .md5, ${cuda_src}} + +${md5_files}: %.md5: % + ${md5sum} $< > $@ + +test/hip: + mkdir -p $@ + #------------------------------------------------------------------------------- # if re-configured, recompile everything $(lib_obj) $(tester_obj): make.inc @@ -286,9 +375,16 @@ hooks: ${hooks} cp $< $@ ; \ fi +# .hip.cc rule before .cc rule. +%.hip.o: %.hip.cc + $(HIPCC) $(HIPCCFLAGS) -c $< -o $@ + %.o: %.cc $(CXX) $(CXXFLAGS) -c $< -o $@ +%.o: %.cu + $(NVCC) $(NVCCFLAGS) -c $< -o $@ + # preprocess source %.i: %.cc $(CXX) $(CXXFLAGS) -I$(testsweeper_dir) -E $< -o $@ @@ -333,6 +429,24 @@ echo: @echo @echo "dep = $(dep)" @echo + @echo "---------- CUDA options" + @echo "cuda = '$(cuda)'" + @echo "NVCC = $(NVCC)" + @echo "NVCC_which = $(NVCC_which)" + @echo "CUDA_PATH = $(CUDA_PATH)" + @echo "NVCCFLAGS = $(NVCCFLAGS)" + @echo + @echo "---------- HIP options" + @echo "hip = '$(hip)'" + @echo "HIPCC = $(HIPCC)" + @echo "HIPCC_which = $(HIPCC_which)" + @echo "ROCM_PATH = $(ROCM_PATH)" + @echo "HIPCCFLAGS = $(HIPCCFLAGS)" + @echo "hipify = ${hipify}" + @echo "cuda_src = ${cuda_src}" + @echo "hip_src = ${hip_src}" + @echo "md5_files = $(md5_files)" + @echo @echo "testsweeper_dir = $(testsweeper_dir)" @echo "testsweeper_src = $(testsweeper_src)" @echo "testsweeper = $(testsweeper)" diff --git a/config/config.py b/config/config.py index 71d923d9..e151ca70 100644 --- a/config/config.py +++ b/config/config.py @@ -77,7 +77,7 @@ def define( var, value=None ): # ------------------------------------------------------------------------------ # variables to replace instead of appending/prepending -replace_vars = ['CC', 'CXX', 'NVCC', 'FC', 'AR', 'RANLIB', 'prefix'] +replace_vars = ['CC', 'CXX', 'NVCC', 'FC', 'AR', 'RANLIB', 'prefix', 'gpu_backend'] # ------------------------------------------------------------------------------ # map file extensions to languages @@ -615,6 +615,24 @@ def openmp( flags=['-fopenmp', '-qopenmp', '-openmp', '-omp', ''] ): # end # end +#------------------------------------------------------------------------------- +def float16( ): + ''' + Tests for _Float16 support from the compiler. + ''' + print_header( '_Float16 support' ) + src = 'config/return_float16.cc' + cxxflags = define('HAVE_ISO_FLOAT16') + print_test( cxxflags ) + env = {'CXXFLAGS': cxxflags} + (rc, out, err) = compile_run( src, env ) + print_result( "_Float16", rc ) + if (rc == 0): + environ.merge( env ) + else: + print_msg( font.red( 'skipping _Float16 search' ) ) +# end + #------------------------------------------------------------------------------- def cublas_library(): ''' @@ -752,6 +770,7 @@ def gpu_blas(): try: cublas_library() gpu_blas_found = True + environ.merge( {'gpu_backend' : 'cuda' } ) except Error as ex: if (gpu_backend == 'cuda'): raise ex # fatal @@ -763,6 +782,7 @@ def gpu_blas(): try: rocblas_library() gpu_blas_found = True + environ.merge( {'gpu_backend' : 'hip' } ) except Error as ex: if (gpu_backend in ('hip', 'rocm')): raise ex # fatal @@ -773,6 +793,7 @@ def gpu_blas(): if (not gpu_blas_found and test_sycl): try: sycl_onemkl_library() + environ.merge( {'gpu_backend' : 'sycl' } ) gpu_blas_found = True except Error as ex: if (gpu_backend == 'sycl'): diff --git a/config/return_float16.cc b/config/return_float16.cc new file mode 100644 index 00000000..928246b0 --- /dev/null +++ b/config/return_float16.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include + +#include "config.h" + +//------------------------------------------------------------------------------ +int main() +{ + _Float16 a = 0.1; + _Float16 b = 0.2; + _Float16 c = a + b; + + printf( "%f + %f = %f -- expected 0.3\n", (float)a, (float)b, (float)c ); + return 0; +} diff --git a/configure.py b/configure.py index 00ba0900..6491a00d 100755 --- a/configure.py +++ b/configure.py @@ -57,6 +57,8 @@ def main(): #config.prog_cxx_flag( '-Wconversion' ) #config.prog_cxx_flag( '-Werror' ) + config.float16() + config.openmp() config.lapack.blas() diff --git a/include/blas/device_blas.hh b/include/blas/device_blas.hh index 2321bf96..a3530075 100644 --- a/include/blas/device_blas.hh +++ b/include/blas/device_blas.hh @@ -207,6 +207,18 @@ void swap( // Level 3 BLAS //------------------------------------------------------------------------------ +void gemm( + blas::Layout layout, + blas::Op transA, + blas::Op transB, + int64_t m, int64_t n, int64_t k, + float16 alpha, + float16 const* A, int64_t lda, + float16 const* B, int64_t ldb, + float16 beta, + float16* C, int64_t ldc, + blas::Queue& queue ); + void gemm( blas::Layout layout, blas::Op transA, diff --git a/include/blas/fortran.h b/include/blas/fortran.h index 89973302..417eb8d6 100644 --- a/include/blas/fortran.h +++ b/include/blas/fortran.h @@ -903,6 +903,20 @@ void BLAS_ztrsv_base( // ============================================================================= // Level 3 BLAS - Fortran prototypes +#if defined(BLAS_HAVE_MKL) +#include +// ----------------------------------------------------------------------------- +#define BLAS_hgemm BLAS_FORTRAN_NAME( hgemm, HGEMM ) +void BLAS_hgemm( + char const *transA, char const *transB, + blas_int const *m, blas_int const *n, blas_int const *k, + MKL_F16 const *alpha, + MKL_F16 const *A, blas_int const *lda, + MKL_F16 const *B, blas_int const *ldb, + MKL_F16 const *beta, + MKL_F16 *C, blas_int const *ldc ); +#endif + // ----------------------------------------------------------------------------- #define BLAS_sgemm_base BLAS_FORTRAN_NAME( sgemm, SGEMM ) void BLAS_sgemm_base( diff --git a/include/blas/util.hh b/include/blas/util.hh index 298fca51..c16bb032 100644 --- a/include/blas/util.hh +++ b/include/blas/util.hh @@ -14,8 +14,134 @@ #include +#include + +#ifdef BLAS_HAVE_CUBLAS +#include +#elif defined(BLAS_HAVE_ROCBLAS) +#include +#elif defined(BLAS_HAVE_MKL) +#include +#endif + namespace blas { +#ifdef BLAS_HAVE_ISO_FLOAT16 + using float16 = _Float16; + +#elif defined(BLAS_HAVE_CUBLAS) + using float16 = __half; + +#elif defined(BLAS_HAVE_ROCBLAS) + using float16 = rocblas_half; + +#else +class float16 { +#if defined(BLAS_HAVE_MKL) + using float16_ = MKL_F16; +#else + using float16_ = uint16_t; +#endif + +public: + float16() : data_( 0.0f ) { } + + // TODO manipulate the bits here + float16( float v ) { data_ = float_to_float16( v ); } + + // TODO manipulate the bits here + operator float() const { + return float16_to_float( data_ ); + } + +private: + float16_ data_; + + typedef union { + float16_ data; + struct { + unsigned int frac : 10; + unsigned int exp : 5; + unsigned int sign : 1; + } bits; + } float16_repr_data_t; + + typedef union { + float data; + struct { + unsigned int frac : 23; + unsigned int exp : 8; + unsigned int sign : 1; + } bits; + } float_repr_data_t; + + static float float16_to_float(float16_ x) { + float16_repr_data_t src; + float_repr_data_t dst; + + src.data = x; + dst.data = 0; + dst.bits.sign = src.bits.sign; + + if (src.bits.exp == 0x01fU) { + dst.bits.exp = 0xffU; + if (src.bits.frac > 0) { + dst.bits.frac = ((src.bits.frac | 0x200U) << 13); + } + } else if (src.bits.exp > 0x00U) { + dst.bits.exp = src.bits.exp + ((1 << 7) - (1 << 4)); + dst.bits.frac = (src.bits.frac << 13); + } else { + unsigned int v = (src.bits.frac << 13); + + if (v > 0) { + dst.bits.exp = 0x71; + while ((v & 0x800000UL) == 0) { + dst.bits.exp --; + v <<= 1; + } + dst.bits.frac = v; + } + } + + return dst.data; + } + + static float16_ float_to_float16(float x) { + float_repr_data_t src; + float16_repr_data_t dst; + + src.data = x; + dst.data = 0; + dst.bits.sign = src.bits.sign; + + if (src.bits.exp == 0x0ffU) { + dst.bits.exp = 0x01fU; + dst.bits.frac = (src.bits.frac >> 13); + if (src.bits.frac > 0) dst.bits.frac |= 0x200U; + } else if (src.bits.exp >= 0x08fU) { + dst.bits.exp = 0x01fU; + dst.bits.frac = 0x000U; + } else if (src.bits.exp >= 0x071U){ + dst.bits.exp = src.bits.exp + ((1 << 4) - (1 << 7)); + dst.bits.frac = (src.bits.frac >> 13); + } else if (src.bits.exp >= 0x067U){ + dst.bits.exp = 0x000; + if (src.bits.frac > 0) { + dst.bits.frac = (((1U << 23) | src.bits.frac) >> 14); + } else { + dst.bits.frac = 1; + } + } + + return dst.data; + } +}; +#endif + +inline float real(float16& a) { return float( a ); } +inline float imag(float16& a) { return 0.0f; } + /// Use to silence compiler warning of unused variable. #define blas_unused( var ) ((void)var) diff --git a/include/blas/wrappers.hh b/include/blas/wrappers.hh index 3afc14d0..8d68e59a 100644 --- a/include/blas/wrappers.hh +++ b/include/blas/wrappers.hh @@ -701,6 +701,17 @@ void trsv( // Level 3 BLAS //------------------------------------------------------------------------------ +void gemm( + blas::Layout layout, + blas::Op transA, + blas::Op transB, + int64_t m, int64_t n, int64_t k, + float16 alpha, + float16 const* A, int64_t lda, + float16 const* B, int64_t ldb, + float16 beta, + float16* C, int64_t ldc ); + void gemm( blas::Layout layout, blas::Op transA, diff --git a/make.inc.in b/make.inc.in index fae87ec5..a73d10fa 100644 --- a/make.inc.in +++ b/make.inc.in @@ -5,17 +5,19 @@ # CPATH: @CPATH@ # LIBRARY_PATH: @LIBRARY_PATH@ # -CXX = @CXX@ +CXX = @CXX@ -CXXFLAGS = @CXXFLAGS@ +CXXFLAGS = @CXXFLAGS@ # see include/blas/defines.h # @DEFINES@ -LDFLAGS = @LDFLAGS@ +LDFLAGS = @LDFLAGS@ -LIBS = @LIBS@ +LIBS = @LIBS@ -prefix = @prefix@ +prefix = @prefix@ -static = @static@ +static = @static@ + +gpu_backend = @gpu_backend@ diff --git a/src/cublas_wrappers.cc b/src/cublas_wrappers.cc index 4238d381..33b73d59 100644 --- a/src/cublas_wrappers.cc +++ b/src/cublas_wrappers.cc @@ -478,6 +478,30 @@ void copy( //------------------------------------------------------------------------------ // gemm +//------------------------------------------------------------------------------ +void gemm( + blas::Op transA, blas::Op transB, + device_blas_int m, device_blas_int n, device_blas_int k, + float16 alpha, + float16 const *dA, device_blas_int ldda, + float16 const *dB, device_blas_int lddb, + float16 beta, + float16 *dC, device_blas_int lddc, + blas::Queue& queue ) +{ + __half alpha_ = *((__half*)&alpha); + __half beta_ = *((__half*)&beta); + + blas_dev_call( + cublasHgemm( + queue.handle(), + op2cublas(transA), op2cublas(transB), + m, n, k, + &alpha_, (__half const*)dA, ldda, + (__half const*)dB, lddb, + &beta_, (__half*)dC, lddc ) ); +} + //------------------------------------------------------------------------------ void gemm( blas::Op transA, blas::Op transB, diff --git a/src/device_gemm.cc b/src/device_gemm.cc index a7af0b5a..f330aed7 100644 --- a/src/device_gemm.cc +++ b/src/device_gemm.cc @@ -103,6 +103,25 @@ void gemm( //============================================================================== // High-level overloaded wrappers call mid-level templated wrapper. +//------------------------------------------------------------------------------ +/// GPU device, float16 version. +/// @ingroup gemm +void gemm( + blas::Layout layout, + blas::Op transA, + blas::Op transB, + int64_t m, int64_t n, int64_t k, + float16 alpha, + float16 const* A, int64_t lda, + float16 const* B, int64_t ldb, + float16 beta, + float16* C, int64_t ldc, + blas::Queue& queue ) +{ + impl::gemm( layout, transA, transB, m, n, k, + alpha, A, lda, B, ldb, beta, C, ldc, queue ); +} + //------------------------------------------------------------------------------ /// GPU device, float version. /// @ingroup gemm diff --git a/src/device_internal.hh b/src/device_internal.hh index a70ba304..ddd5e6d0 100644 --- a/src/device_internal.hh +++ b/src/device_internal.hh @@ -480,6 +480,16 @@ void copy( // Level 3 BLAS - Device Interfaces //------------------------------------------------------------------------------ +void gemm( + blas::Op transA, blas::Op transB, + device_blas_int m, device_blas_int n, device_blas_int k, + float16 alpha, + float16 const *dA, device_blas_int ldda, + float16 const *dB, device_blas_int lddb, + float16 beta, + float16 *dC, device_blas_int lddc, + blas::Queue& queue ); + void gemm( blas::Op transA, blas::Op transB, device_blas_int m, device_blas_int n, device_blas_int k, diff --git a/src/gemm.cc b/src/gemm.cc index 2d57509d..bfc0d8fa 100644 --- a/src/gemm.cc +++ b/src/gemm.cc @@ -9,11 +9,35 @@ #include +#if defined(BLAS_HAVE_MKL) + #include +#endif + namespace blas { //============================================================================== namespace internal { +//------------------------------------------------------------------------------ +/// Low-level overload wrapper calls Fortran, float16 version. +/// @ingroup gemm_internal +inline void gemm( + char transA, char transB, + blas_int m, blas_int n, blas_int k, + float16 alpha, + float16 const* A, blas_int lda, + float16 const* B, blas_int ldb, + float16 beta, + float16* C, blas_int ldc ) +{ +#ifdef BLAS_HAVE_MKL + BLAS_hgemm( &transA, &transB, &m, &n, &k, + (MKL_F16*)&alpha, (MKL_F16*)A, &lda, + (MKL_F16*)B, &ldb, + (MKL_F16*)&beta, (MKL_F16*)C, &ldc ); +#endif +} + //------------------------------------------------------------------------------ /// Low-level overload wrapper calls Fortran, float version. /// @ingroup gemm_internal @@ -179,6 +203,24 @@ void gemm( // When calling a template, all the templated arguments (e.g., scalar_t) // must match types exactly. +//------------------------------------------------------------------------------ +/// CPU, float16 version. +/// @ingroup gemm +void gemm( + blas::Layout layout, + blas::Op transA, + blas::Op transB, + int64_t m, int64_t n, int64_t k, + float16 alpha, + float16 const* A, int64_t lda, + float16 const* B, int64_t ldb, + float16 beta, + float16* C, int64_t ldc ) +{ + impl::gemm( layout, transA, transB, m, n, k, + alpha, A, lda, B, ldb, beta, C, ldc ); +} + //------------------------------------------------------------------------------ /// CPU, float version. /// @ingroup gemm diff --git a/src/onemkl_wrappers.cc b/src/onemkl_wrappers.cc index e1f3d05e..340eb94e 100644 --- a/src/onemkl_wrappers.cc +++ b/src/onemkl_wrappers.cc @@ -496,6 +496,27 @@ void copy( //------------------------------------------------------------------------------ // gemm +//------------------------------------------------------------------------------ +//void gemm( +// blas::Op transA, blas::Op transB, +// device_blas_int m, device_blas_int n, device_blas_int k, +// float16 alpha, +// float16 const *dA, device_blas_int ldda, +// float16 const *dB, device_blas_int lddb, +// float16 beta, +// float16 *dC, device_blas_int lddc, +// blas::Queue& queue ) +//{ +// blas_dev_call( +// oneapi::mkl::blas::gemm( +// queue.stream(), +// op2onemkl( transA ), op2onemkl( transB ), +// m, n, k, +// (MKL_F16)alpha, (MKL_F16*)dA, ldda, +// (MKL_F16*)dB, lddb, +// (MKL_F16)beta, (MKL_F16*)dC, lddc ) ); +//} + //------------------------------------------------------------------------------ void gemm( blas::Op transA, blas::Op transB, diff --git a/src/rocblas_wrappers.cc b/src/rocblas_wrappers.cc index a0581bd0..7894cdf7 100644 --- a/src/rocblas_wrappers.cc +++ b/src/rocblas_wrappers.cc @@ -484,6 +484,31 @@ void copy( //------------------------------------------------------------------------------ // gemm //------------------------------------------------------------------------------ +void gemm( + blas::Op transA, blas::Op transB, + device_blas_int m, device_blas_int n, device_blas_int k, + float16 alpha, + float16 const *dA, device_blas_int ldda, + float16 const *dB, device_blas_int lddb, + float16 beta, + float16 *dC, device_blas_int lddc, + blas::Queue& queue ) +{ + // Cast from blas::float16 to rocblas_half + rocblas_half alpha_ = *(rocblas_half*)α + rocblas_half beta_ = *(rocblas_half*)β + + blas_dev_call( + rocblas_hgemm( + queue.handle(), + op2rocblas(transA), op2rocblas(transB), + m, n, k, + &alpha_, (rocblas_half const*)dA, ldda, + (rocblas_half const*)dB, lddb, + &beta_, (rocblas_half*)dC, lddc ) ); +} +// +//------------------------------------------------------------------------------ void gemm( blas::Op transA, blas::Op transB, device_blas_int m, device_blas_int n, device_blas_int k, diff --git a/test/cblas_wrappers.hh b/test/cblas_wrappers.hh index 60b9a8bc..0e8c2ac6 100644 --- a/test/cblas_wrappers.hh +++ b/test/cblas_wrappers.hh @@ -1156,7 +1156,24 @@ cblas_syr2( // ============================================================================= // Level 3 BLAS +#if defined(BLAS_HAVE_MKL) // ----------------------------------------------------------------------------- +inline void +cblas_gemm( + CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int m, int n, int k, + blas::float16 alpha, + blas::float16 const *A, int lda, + blas::float16 const *B, int ldb, + blas::float16 beta, + blas::float16* C, int ldc ) +{ + cblas_hgemm( layout, transA, transB, m, n, k, + (MKL_F16)alpha, (MKL_F16*)A, lda, (MKL_F16*)B, ldb, + (MKL_F16)beta, (MKL_F16*)C, ldc ); +} +#endif + inline void cblas_gemm( CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, diff --git a/test/check_gemm.hh b/test/check_gemm.hh index 385a25a0..6c7433e9 100644 --- a/test/check_gemm.hh +++ b/test/check_gemm.hh @@ -13,11 +13,21 @@ #include +namespace std { + template <> class numeric_limits { + public: + static blas::float16 epsilon() { + // Value coming from MAGMA testing/testing_hgemm.cpp + return blas::float16( 0.00097656 ); + } + }; +}; // namespace std + // ----------------------------------------------------------------------------- // Computes error for multiplication with general matrix result. // Covers dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm?. // Cnorm is norm of original C, before multiplication operation. -template +template void check_gemm( int64_t m, int64_t n, int64_t k, T alpha, @@ -54,8 +64,8 @@ void check_gemm( real_t work[1], Cout_norm; Cout_norm = lapack_lange( "f", m, n, C, ldc, work ); error[0] = Cout_norm - / (sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm - + 2 * abs( beta ) * Cnorm); + / ( ( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm + + 2 * abs( beta ) * Cnorm ); if (verbose) { printf( "error: ||Cout||=%.2e / (sqrt(k=%lld + 2)" " * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e" @@ -70,7 +80,7 @@ void check_gemm( error[0] /= 2*sqrt(2); } - real_t u = 0.5 * std::numeric_limits< real_t >::epsilon(); + real_t u = 0.5 * std::numeric_limits>::epsilon(); *okay = (error[0] < u); #undef C diff --git a/test/cuda/utils.cu b/test/cuda/utils.cu new file mode 100644 index 00000000..ded6adc6 --- /dev/null +++ b/test/cuda/utils.cu @@ -0,0 +1,180 @@ +#include "../utils.hh" + +#if defined(BLAS_HAVE_CUBLAS) + #include +#elif defined(BLAS_HAVE_ROCBLAS) + #include +#endif + +//------------------------------------------------------------------------------ +/// @return ceil( x / y ), for integer type T. +template +inline constexpr T ceildiv( T x, T y ) +{ + return T( (x + y - 1) / y ); +} + +//============================================================================== +// Overloads to enable templating. + +//------------------------------------------------------------------------------ +// Template implementation when C++ has default rules for conversion +// (or no conversion needed). +template +__host__ __device__ +inline void copy_scalar( src_t src, dst_t& dst ) +{ + dst = dst_t( src ); +} + +//------------------------------------------------------------------------------ +// Overloaded implementations for specific cases. +__host__ __device__ +inline void copy_scalar( float src, __half& dst ) +{ + dst = __float2half( src ); +} + +//------------------------------------------------------------------------------ +__host__ __device__ +inline void copy_scalar( double src, __half& dst ) +{ +#ifdef __NVCC__ + dst = __double2half( src ); +#else + dst = __float2half( (float)src ); +#endif +} + +//------------------------------------------------------------------------------ +__host__ __device__ +inline void copy_scalar( __half src, float& dst ) +{ + dst = __half2float( src ); +} + +//------------------------------------------------------------------------------ +__host__ __device__ +inline void copy_scalar( __half src, double& dst ) +{ + // no __half2double, so do in 2 steps + dst = double( __half2float( src ) ); +} + +//============================================================================== +// GPU function. +const int blk_x = 32; +const int blk_y = 4; + +//------------------------------------------------------------------------------ +// GPU device routine, called from GPU kernel. +// +// Each thread-block does a blk_x by blk_y block of the matrix; +// each thread does 1 entry in the block. +// Because of the CUDA max y grid dimension of 65535, the entire grid is +// repeated in the y dimension with step = gridDim.y * blockDim.y, +// so thread (i, j) will do entries (i, j), (i, j + step), (i, j + 2*step), ... +// The max x grid dimension is 2^31-1, so there's no need to repeat. +// Cf. magma/magmablas/slag2h.cu +// +template +__device__ +void copy_matrix_device( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst ) +{ + // Global thread index. + const int ti = blockIdx.x * blockDim.x + threadIdx.x; + const int tj = blockIdx.y * blockDim.y + threadIdx.y; + + if (ti < m) { + for (int j = tj; j < n; j += gridDim.y * blockDim.y) { + copy_scalar( src[ ti + j*ld_src ], + dst[ ti + j*ld_dst ] ); + } + } +} + +//------------------------------------------------------------------------------ +// GPU kernel, called from CPU driver. +template +__global__ +void copy_matrix_kernel( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst ) +{ + copy_matrix_device( m, n, src, ld_src, dst, ld_dst ); +} + +namespace blas { + +//------------------------------------------------------------------------------ +// Copy m-by-n src matrix to dst matrix, with type conversion. +template +void copy_matrix( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst, + blas::Queue &queue) +{ + + cudaStream_t stream = queue.stream(); + + // CUDA has max x grid dimension of 2^31-1; y, z grid dimensions of 65535. + dim3 threads( blk_x, blk_y ); + dim3 blocks( ceildiv( m, blk_x ), std::min( 4, ceildiv( n, blk_y ) ) ); + + // printf( "%s: m %d, n %d; threads %d, %d, %d; blocks %d, %d, %d\n", + // __func__, + // threads.x, threads.y, threads.z, + // blocks.x, blocks.y, blocks.z ); + + copy_matrix_kernel<<< blocks, threads, 0, stream >>> + ( m, n, src, ld_src, dst, ld_dst ); + + // Check that launch succeeded. This doesn't say execution will be + // successful, it checks only the launch. + blas_dev_call( + cudaGetLastError() ); +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template <> +void copy_matrix( + int m, int n, + float const* src, int ld_src, + float16* dst, int ld_dst, + blas::Queue &queue) +{ + copy_matrix( + m, n, + src, ld_src, + (__half*) dst, ld_dst, + queue ); +} + +template <> +void copy_matrix( + int m, int n, + float16 const* src, int ld_src, + float* dst, int ld_dst, + blas::Queue &queue) +{ + copy_matrix( + m, n, + (__half*) src, ld_src, + dst, ld_dst, + queue ); +} + +template <> +void copy_matrix( + int m, int n, + float const* src, int ld_src, + float* dst, int ld_dst, + blas::Queue &queue); + +} // namespace blas diff --git a/test/hip/utils.hip.cc b/test/hip/utils.hip.cc new file mode 100644 index 00000000..cd46eeda --- /dev/null +++ b/test/hip/utils.hip.cc @@ -0,0 +1,180 @@ +#include "hip/hip_runtime.h" +#include "../utils.hh" + +#if defined(BLAS_HAVE_CUBLAS) + #include +#elif defined(BLAS_HAVE_ROCBLAS) + #include +#endif + +//------------------------------------------------------------------------------ +/// @return ceil( x / y ), for integer type T. +template +inline constexpr T ceildiv( T x, T y ) +{ + return T( (x + y - 1) / y ); +} + +//============================================================================== +// Overloads to enable templating. + +//------------------------------------------------------------------------------ +// Template implementation when C++ has default rules for conversion +// (or no conversion needed). +template +__host__ __device__ +inline void copy_scalar( src_t src, dst_t& dst ) +{ + dst = dst_t( src ); +} + +//------------------------------------------------------------------------------ +// Overloaded implementations for specific cases. +__host__ __device__ +inline void copy_scalar( float src, __half& dst ) +{ + dst = __float2half( src ); +} + +//------------------------------------------------------------------------------ +__host__ __device__ +inline void copy_scalar( double src, __half& dst ) +{ +#ifdef __NVCC__ + dst = __double2half( src ); +#else + dst = __float2half( (float)src ); +#endif +} + +//------------------------------------------------------------------------------ +__host__ __device__ +inline void copy_scalar( __half src, float& dst ) +{ + dst = __half2float( src ); +} + +//------------------------------------------------------------------------------ +__host__ __device__ +inline void copy_scalar( __half src, double& dst ) +{ + // no __half2double, so do in 2 steps + dst = double( __half2float( src ) ); +} + +//============================================================================== +// GPU function. +const int blk_x = 32; +const int blk_y = 4; + +//------------------------------------------------------------------------------ +// GPU device routine, called from GPU kernel. +// +// Each thread-block does a blk_x by blk_y block of the matrix; +// each thread does 1 entry in the block. +// Because of the CUDA max y grid dimension of 65535, the entire grid is +// repeated in the y dimension with step = gridDim.y * blockDim.y, +// so thread (i, j) will do entries (i, j), (i, j + step), (i, j + 2*step), ... +// The max x grid dimension is 2^31-1, so there's no need to repeat. +// Cf. magma/magmablas/slag2h.cu +// +template +__device__ +void copy_matrix_device( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst ) +{ + // Global thread index. + const int ti = blockIdx.x * blockDim.x + threadIdx.x; + const int tj = blockIdx.y * blockDim.y + threadIdx.y; + + if (ti < m) { + for (int j = tj; j < n; j += gridDim.y * blockDim.y) { + copy_scalar( src[ ti + j*ld_src ], + dst[ ti + j*ld_dst ] ); + } + } +} + +//------------------------------------------------------------------------------ +// GPU kernel, called from CPU driver. +template +__global__ +void copy_matrix_kernel( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst ) +{ + copy_matrix_device( m, n, src, ld_src, dst, ld_dst ); +} + +namespace blas { + +//------------------------------------------------------------------------------ +// Copy m-by-n src matrix to dst matrix, with type conversion. +template +void copy_matrix( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst, + blas::Queue &queue) +{ + + hipStream_t stream = queue.stream(); + + // CUDA has max x grid dimension of 2^31-1; y, z grid dimensions of 65535. + dim3 threads( blk_x, blk_y ); + dim3 blocks( ceildiv( m, blk_x ), std::min( 4, ceildiv( n, blk_y ) ) ); + + // printf( "%s: m %d, n %d; threads %d, %d, %d; blocks %d, %d, %d\n", + // __func__, + // threads.x, threads.y, threads.z, + // blocks.x, blocks.y, blocks.z ); + + copy_matrix_kernel<<< blocks, threads, 0, stream >>> + ( m, n, src, ld_src, dst, ld_dst ); + + // Check that launch succeeded. This doesn't say execution will be + // successful, it checks only the launch. + blas_dev_call( + hipGetLastError() ); +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template <> +void copy_matrix( + int m, int n, + float const* src, int ld_src, + float16* dst, int ld_dst, + blas::Queue &queue) +{ + copy_matrix( + m, n, + src, ld_src, + (__half*) dst, ld_dst, + queue ); +} + +template <> +void copy_matrix( + int m, int n, + float16 const* src, int ld_src, + float* dst, int ld_dst, + blas::Queue &queue) +{ + copy_matrix( + m, n, + (__half*) src, ld_src, + dst, ld_dst, + queue ); +} + +template <> +void copy_matrix( + int m, int n, + float const* src, int ld_src, + float* dst, int ld_dst, + blas::Queue &queue); +} // namespace blas diff --git a/test/hip/utils.hip.cc.dep b/test/hip/utils.hip.cc.dep new file mode 100644 index 00000000..9b338030 --- /dev/null +++ b/test/hip/utils.hip.cc.dep @@ -0,0 +1 @@ +7277516b6e785d5947d13ce9cac5b4f4 test/cuda/utils.cu diff --git a/test/test.cc b/test/test.cc index 55dcbb03..8518f77d 100644 --- a/test/test.cc +++ b/test/test.cc @@ -212,7 +212,7 @@ Params::Params(): // ----- routine parameters // name, w, type, def, char2enum, enum2char, enum2str, help - datatype ( "type", 4, ParamType::List, DataType::Double, char2datatype, datatype2char, datatype2str, "s=single (float), d=double, c=complex-single, z=complex-double" ), + datatype ( "type", 4, ParamType::List, DataType::Double, char2datatype, datatype2char, datatype2str, "h=half, s=single (float), d=double, c=complex-single, z=complex-double" ), layout ( "layout", 6, ParamType::List, blas::Layout::ColMajor, blas::char2layout, blas::layout2char, blas::layout2str, "layout: r=row major, c=column major" ), format ( "format", 6, ParamType::List, blas::Format::LAPACK, blas::char2format, blas::format2char, blas::format2str, "format: l=lapack, t=tile" ), side ( "side", 6, ParamType::List, blas::Side::Left, blas::char2side, blas::side2char, blas::side2str, "side: l=left, r=right" ), diff --git a/test/test_gemm.cc b/test/test_gemm.cc index 55605ed3..549ac95e 100644 --- a/test/test_gemm.cc +++ b/test/test_gemm.cc @@ -10,6 +10,8 @@ #include "print_matrix.hh" #include "check_gemm.hh" +#include "utils.hh" + // ----------------------------------------------------------------------------- template void test_gemm_work( Params& params, bool run ) @@ -170,10 +172,194 @@ void test_gemm_work( Params& params, bool run ) delete[] Cref; } +// ----------------------------------------------------------------------------- +template <> +void test_gemm_work( + Params& params, bool run ) +{ + using namespace testsweeper; + using std::real; + using std::imag; + using blas::real; + using blas::imag; + using blas::Op; + using blas::Layout; + using scalar_hi = float; + using scalar_lo = blas::float16; + using real_t = blas::real_type< scalar_hi >; + + // get & mark input values + blas::Layout layout = params.layout(); + blas::Op transA = params.transA(); + blas::Op transB = params.transB(); + scalar_lo alpha = (scalar_lo)params.alpha(); + scalar_lo beta = (scalar_lo)params.beta(); + int64_t m = params.dim.m(); + int64_t n = params.dim.n(); + int64_t k = params.dim.k(); + int64_t align = params.align(); + int64_t verbose = params.verbose(); + + // mark non-standard output values + params.gflops(); + params.ref_time(); + params.ref_gflops(); + + if (! run) + return; + + // setup + int64_t Am = (transA == Op::NoTrans ? m : k); + int64_t An = (transA == Op::NoTrans ? k : m); + int64_t Bm = (transB == Op::NoTrans ? k : n); + int64_t Bn = (transB == Op::NoTrans ? n : k); + int64_t Cm = m; + int64_t Cn = n; + if (layout == Layout::RowMajor) { + std::swap( Am, An ); + std::swap( Bm, Bn ); + std::swap( Cm, Cn ); + } + int64_t lda = roundup( Am, align ); + int64_t ldb = roundup( Bm, align ); + int64_t ldc = roundup( Cm, align ); + size_t size_A = size_t(lda)*An; + size_t size_B = size_t(ldb)*Bn; + size_t size_C = size_t(ldc)*Cn; + blas::float16* A_lo = new blas::float16[ size_A ]; + blas::float16* B_lo = new blas::float16[ size_B ]; + blas::float16* C_lo = new blas::float16[ size_C ]; + float* A_hi = new float[ size_A ]; + float* B_hi = new float[ size_B ]; + float* C_hi = new float[ size_C ]; + float* Cref = new float[ size_C ]; + + int64_t idist = 1; + int iseed[4] = { 0, 0, 0, 1 }; + lapack_larnv( idist, iseed, size_A, A_hi ); + lapack_larnv( idist, iseed, size_B, B_hi ); + lapack_larnv( idist, iseed, size_C, C_hi ); + lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc ); + + // Convert float->float16 + blas::copy_matrix( Am, An, A_hi, lda, A_lo, lda ); + blas::copy_matrix( Bm, Bn, B_hi, ldb, B_lo, ldb ); + blas::copy_matrix( Cm, Cn, C_hi, ldc, C_lo, ldc ); + + // norms for error check + real_t work[1]; + real_t Anorm = lapack_lange( "f", Am, An, A_hi, lda, work ); + real_t Bnorm = lapack_lange( "f", Bm, Bn, B_hi, ldb, work ); + real_t Cnorm = lapack_lange( "f", Cm, Cn, C_hi, ldc, work ); + + // test error exits + assert_throw( blas::gemm( Layout(0), transA, transB, m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( layout, Op(0), transB, m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( layout, transA, Op(0), m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( layout, transA, transB, -1, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( layout, transA, transB, m, -1, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( layout, transA, transB, m, n, -1, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, A_hi, m-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, A_hi, k-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, A_hi, k-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, A_hi, k-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, A_hi, m-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, A_hi, m-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error ); + + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, A_hi, lda, B_hi, k-1, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, A_hi, lda, B_hi, n-1, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, A_hi, lda, B_hi, n-1, beta, C_hi, ldc ), blas::Error ); + + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, A_hi, lda, B_hi, n-1, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, A_hi, lda, B_hi, k-1, beta, C_hi, ldc ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, A_hi, lda, B_hi, k-1, beta, C_hi, ldc ), blas::Error ); + + assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, m-1 ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, n-1 ), blas::Error ); + + if (verbose >= 1) { + printf( "\n" + "A Am=%5lld, An=%5lld, lda=%5lld, size=%10lld, norm %.2e\n" + "B Bm=%5lld, Bn=%5lld, ldb=%5lld, size=%10lld, norm %.2e\n" + "C Cm=%5lld, Cn=%5lld, ldc=%5lld, size=%10lld, norm %.2e\n", + llong( Am ), llong( An ), llong( lda ), llong( size_A ), Anorm, + llong( Bm ), llong( Bn ), llong( ldb ), llong( size_B ), Bnorm, + llong( Cm ), llong( Cn ), llong( ldc ), llong( size_C ), Cnorm ); + } + if (verbose >= 2) { + printf( "alpha = %.4e + %.4ei; beta = %.4e + %.4ei;\n", + real(alpha), imag(alpha), + real(beta), imag(beta) ); + printf( "A = " ); print_matrix( Am, An, A_hi, lda ); + printf( "B = " ); print_matrix( Bm, Bn, B_hi, ldb ); + printf( "C = " ); print_matrix( Cm, Cn, C_hi, ldc ); + } + + // run test + testsweeper::flush_cache( params.cache() ); + double time = get_wtime(); + blas::gemm( layout, transA, transB, m, n, k, + alpha, A_lo, lda, B_lo, ldb, beta, C_lo, ldc ); + time = get_wtime() - time; + + double gflop = blas::Gflop< scalar_lo >::gemm( m, n, k ); + params.time() = time; + params.gflops() = gflop / time; + + // Convert float16->float + blas::copy_matrix( Cm, Cn, C_lo, ldc, C_hi, ldc ); + + if (verbose >= 2) { + printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc ); + } + + if (params.ref() == 'y' || params.check() == 'y') { + // run reference + testsweeper::flush_cache( params.cache() ); + time = get_wtime(); + cblas_gemm( cblas_layout_const(layout), + cblas_trans_const(transA), + cblas_trans_const(transB), + m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, Cref, ldc ); + time = get_wtime() - time; + + params.ref_time() = time; + params.ref_gflops() = gflop / time; + + if (verbose >= 2) { + printf( "Cref = " ); print_matrix( Cm, Cn, Cref, ldc ); + } + + // check error compared to reference + real_t error; + bool okay; + check_gemm( + Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm, + Cref, ldc, C_hi, ldc, verbose, &error, &okay ); + params.error() = error; + params.okay() = okay; + } + + delete[] A_lo; + delete[] B_lo; + delete[] C_lo; + delete[] A_hi; + delete[] B_hi; + delete[] C_hi; + delete[] Cref; +} + // ----------------------------------------------------------------------------- void test_gemm( Params& params, bool run ) { switch (params.datatype()) { + case testsweeper::DataType::Half: + test_gemm_work< blas::float16, blas::float16, blas::float16 >( + params, run ); + break; + case testsweeper::DataType::Single: test_gemm_work< float, float, float >( params, run ); break; diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc index 895cddb2..3ef74701 100644 --- a/test/test_gemm_device.cc +++ b/test/test_gemm_device.cc @@ -10,6 +10,8 @@ #include "print_matrix.hh" #include "check_gemm.hh" +#include "utils.hh" + // ----------------------------------------------------------------------------- template void test_gemm_device_work( Params& params, bool run ) @@ -197,11 +199,228 @@ void test_gemm_device_work( Params& params, bool run ) blas::device_free( dB, queue ); blas::device_free( dC, queue ); } +// +// ----------------------------------------------------------------------------- +template <> +void test_gemm_device_work( Params& params, bool run ) +{ + using namespace testsweeper; + using std::real; + using std::imag; + using blas::Op; + using blas::Layout; + using scalar_hi = float; + using scalar_lo = blas::float16; + using real_t = blas::real_type< scalar_hi >; + + // get & mark input values + blas::Layout layout = params.layout(); + blas::Op transA = params.transA(); + blas::Op transB = params.transB(); + scalar_lo alpha = (scalar_lo)params.alpha(); + scalar_lo beta = (scalar_lo)params.beta(); + int64_t m = params.dim.m(); + int64_t n = params.dim.n(); + int64_t k = params.dim.k(); + int64_t device = params.device(); + int64_t align = params.align(); + int64_t verbose = params.verbose(); + + // mark non-standard output values + params.gflops(); + params.ref_time(); + params.ref_gflops(); + + if (! run) + return; + + if (blas::get_device_count() == 0) { + params.msg() = "skipping: no GPU devices or no GPU support"; + return; + } + + // setup + int64_t Am = (transA == Op::NoTrans ? m : k); + int64_t An = (transA == Op::NoTrans ? k : m); + int64_t Bm = (transB == Op::NoTrans ? k : n); + int64_t Bn = (transB == Op::NoTrans ? n : k); + int64_t Cm = m; + int64_t Cn = n; + if (layout == Layout::RowMajor) { + std::swap( Am, An ); + std::swap( Bm, Bn ); + std::swap( Cm, Cn ); + } + int64_t lda = roundup( Am, align ); + int64_t ldb = roundup( Bm, align ); + int64_t ldc = roundup( Cm, align ); + size_t size_A = size_t(lda)*An; + size_t size_B = size_t(ldb)*Bn; + size_t size_C = size_t(ldc)*Cn; + blas::float16* A_lo = new blas::float16[ size_A ]; + blas::float16* B_lo = new blas::float16[ size_B ]; + blas::float16* C_lo = new blas::float16[ size_C ]; + float* A_hi = new float[ size_A ]; + float* B_hi = new float[ size_B ]; + float* C_hi = new float[ size_C ]; + float* Cref = new float[ size_C ]; + + // device specifics + blas::Queue queue( device ); + blas::float16* dA_lo; + blas::float16* dB_lo; + blas::float16* dC_lo; + float* dA_hi; + float* dB_hi; + float* dC_hi; + + dA_lo = blas::device_malloc( size_A, queue ); + dB_lo = blas::device_malloc( size_B, queue ); + dC_lo = blas::device_malloc( size_C, queue ); + dA_hi = blas::device_malloc( size_A, queue ); + dB_hi = blas::device_malloc( size_B, queue ); + dC_hi = blas::device_malloc( size_C, queue ); + + int64_t idist = 1; + int iseed[4] = { 0, 0, 0, 1 }; + lapack_larnv( idist, iseed, size_A, A_hi ); + lapack_larnv( idist, iseed, size_B, B_hi ); + lapack_larnv( idist, iseed, size_C, C_hi ); + lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc ); + + blas::device_copy_matrix( Am, An, A_hi, lda, dA_hi, lda, queue ); + blas::device_copy_matrix( Bm, Bn, B_hi, ldb, dB_hi, ldb, queue ); + blas::device_copy_matrix( Cm, Cn, C_hi, ldc, dC_hi, ldc, queue ); + + // Convert float->float16 + blas::copy_matrix( Am, An, dA_hi, lda, dA_lo, lda, queue ); + blas::copy_matrix( Bm, Bn, dB_hi, ldb, dB_lo, ldb, queue ); + blas::copy_matrix( Cm, Cn, dC_hi, ldc, dC_lo, ldc, queue ); + queue.sync(); + + // norms for error check + real_t work[1]; + real_t Anorm = lapack_lange( "f", Am, An, A_hi, lda, work ); + real_t Bnorm = lapack_lange( "f", Bm, Bn, B_hi, ldb, work ); + real_t Cnorm = lapack_lange( "f", Cm, Cn, C_hi, ldc, work ); + + // test error exits + assert_throw( blas::gemm( Layout(0), transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( layout, Op(0), transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( layout, transA, Op(0), m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( layout, transA, transB, -1, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( layout, transA, transB, m, -1, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( layout, transA, transB, m, n, -1, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); + + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error ); + + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error ); + + assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, m-1, queue ), blas::Error ); + assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, n-1, queue ), blas::Error ); + + if (verbose >= 1) { + printf( "\n" + "A Am=%5lld, An=%5lld, lda=%5lld, size=%10lld, norm %.2e\n" + "B Bm=%5lld, Bn=%5lld, ldb=%5lld, size=%10lld, norm %.2e\n" + "C Cm=%5lld, Cn=%5lld, ldc=%5lld, size=%10lld, norm %.2e\n", + llong( Am ), llong( An ), llong( lda ), llong( size_A ), Anorm, + llong( Bm ), llong( Bn ), llong( ldb ), llong( size_B ), Bnorm, + llong( Cm ), llong( Cn ), llong( ldc ), llong( size_C ), Cnorm ); + } + if (verbose >= 2) { + printf( "alpha = %.4e + %.4ei; beta = %.4e + %.4ei;\n", + blas::real(alpha), blas::imag(alpha), + blas::real(beta), blas::imag(beta) ); + printf( "A = " ); print_matrix( Am, An, A_hi, lda ); + printf( "B = " ); print_matrix( Bm, Bn, B_hi, ldb ); + printf( "C = " ); print_matrix( Cm, Cn, C_hi, ldc ); + } + + // run test + testsweeper::flush_cache( params.cache() ); + double time = get_wtime(); + blas::gemm( layout, transA, transB, m, n, k, + alpha, dA_lo, lda, dB_lo, ldb, beta, dC_lo, ldc, queue ); + queue.sync(); + time = get_wtime() - time; + + double gflop = blas::Gflop< scalar_lo >::gemm( m, n, k ); + params.time() = time; + params.gflops() = gflop / time; + + // Convert float16->float + blas::copy_matrix( Cm, Cn, dC_lo, ldc, dC_hi, ldc, queue ); + blas::device_copy_matrix(Cm, Cn, dC_hi, ldc, C_hi, ldc, queue); + queue.sync(); + + if (verbose >= 2) { + printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc ); + } + + if (params.ref() == 'y' || params.check() == 'y') { + // run reference + testsweeper::flush_cache( params.cache() ); + time = get_wtime(); + cblas_gemm( cblas_layout_const(layout), + cblas_trans_const(transA), + cblas_trans_const(transB), + m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, Cref, ldc ); // keep it like this as it defines the reference + time = get_wtime() - time; + + params.ref_time() = time; + params.ref_gflops() = gflop / time; + + if (verbose >= 2) { + printf( "Cref = " ); print_matrix( Cm, Cn, Cref, ldc ); + } + + // check error compared to reference + real_t error; + bool okay; + check_gemm( + Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm, + Cref, ldc, C_hi, ldc, verbose, &error, &okay ); + params.error() = error; + params.okay() = okay; + } + + delete[] A_hi; + delete[] B_hi; + delete[] C_hi; + delete[] A_lo; + delete[] B_lo; + delete[] C_lo; + delete[] Cref; + + blas::device_free( dA_hi, queue ); + blas::device_free( dB_hi, queue ); + blas::device_free( dC_hi, queue ); + blas::device_free( dA_lo, queue ); + blas::device_free( dB_lo, queue ); + blas::device_free( dC_lo, queue ); +} // ----------------------------------------------------------------------------- void test_gemm_device( Params& params, bool run ) { switch (params.datatype()) { + case testsweeper::DataType::Half: + test_gemm_device_work< blas::float16, blas::float16, blas::float16 >( params, run ); + break; + case testsweeper::DataType::Single: test_gemm_device_work< float, float, float >( params, run ); break; diff --git a/test/utils.cc b/test/utils.cc new file mode 100644 index 00000000..798d6b45 --- /dev/null +++ b/test/utils.cc @@ -0,0 +1,33 @@ +#include "utils.hh" + +namespace blas { + +template +void copy_matrix( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst) +{ + #pragma omp parallel for collapse(2) + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + dst[i + j*ld_dst] = (dst_t)src[i + j*ld_src]; + } + } +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template +void copy_matrix( + int m, int n, + float const* src, int ld_src, + float16* dst, int ld_dst); + +template +void copy_matrix( + int m, int n, + float16 const* src, int ld_src, + float* dst, int ld_dst); + +} // namespace blas diff --git a/test/utils.hh b/test/utils.hh new file mode 100644 index 00000000..748ff3d2 --- /dev/null +++ b/test/utils.hh @@ -0,0 +1,22 @@ +#ifndef UTILS_HH +#define UTILS_HH + +#include "blas.hh" + +namespace blas { + +template +void copy_matrix( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst); + +template +void copy_matrix( + int m, int n, + src_t const* src, int ld_src, + dst_t* dst, int ld_dst, + blas::Queue &queue); +} // namespace blas + +#endif // UTILS_HH diff --git a/tools/md5sum.pl b/tools/md5sum.pl new file mode 100755 index 00000000..f7658bce --- /dev/null +++ b/tools/md5sum.pl @@ -0,0 +1,17 @@ +#!/usr/bin/perl +# +# Generate md5 sums of files, with output compatible with md5sum. +# Doesn't support any options of md5sum, though (--check, etc.). + +use strict; +use Digest::MD5; + +foreach my $filename (@ARGV) { + my $file; + if (not open( $file, '<', $filename )) { + warn "$0: $filename: $!\n"; + next; + } + binmode( $file ); + print Digest::MD5->new->addfile( $file )->hexdigest, " $filename\n"; +}