From 3998ebdc446351d1fefa608284ed6a02bcdd06a5 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Fri, 22 Sep 2023 15:36:56 +0000
Subject: [PATCH 01/25] hgemm: add half definition and interfaces

---
 include/blas/device_blas.hh | 12 ++++++++++++
 include/blas/util.hh        |  7 +++++++
 src/cublas_wrappers.cc      | 21 +++++++++++++++++++++
 src/device_gemm.cc          | 19 +++++++++++++++++++
 src/device_internal.hh      | 10 ++++++++++
 5 files changed, 69 insertions(+)

diff --git a/include/blas/device_blas.hh b/include/blas/device_blas.hh
index 2321bf96..c268cc74 100644
--- a/include/blas/device_blas.hh
+++ b/include/blas/device_blas.hh
@@ -255,6 +255,18 @@ void gemm(
     std::complex<double>*       C, int64_t ldc,
     blas::Queue& queue );
 
+void gemm(
+    blas::Layout layout,
+    blas::Op transA,
+    blas::Op transB,
+    int64_t m, int64_t n, int64_t k,
+    half alpha,
+    half const* A, int64_t lda,
+    half const* B, int64_t ldb,
+    half beta,
+    half*       C, int64_t ldc,
+    blas::Queue& queue );
+
 //------------------------------------------------------------------------------
 void hemm(
     blas::Layout layout,
diff --git a/include/blas/util.hh b/include/blas/util.hh
index 298fca51..5df085c0 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -14,8 +14,15 @@
 
 #include <assert.h>
 
+#include <cuda_fp16.h>
+
 namespace blas {
 
+using half = __half;
+
+inline float real(half& a) { return a; }
+inline float imag(half& a) { return 0.0; }
+
 /// Use to silence compiler warning of unused variable.
 #define blas_unused( var ) ((void)var)
 
diff --git a/src/cublas_wrappers.cc b/src/cublas_wrappers.cc
index 4238d381..d98eae99 100644
--- a/src/cublas_wrappers.cc
+++ b/src/cublas_wrappers.cc
@@ -566,6 +566,27 @@ void gemm(
             (cuDoubleComplex*) dC, lddc ) );
 }
 
+//------------------------------------------------------------------------------
+void gemm(
+    blas::Op transA, blas::Op transB,
+    device_blas_int m, device_blas_int n, device_blas_int k,
+    half alpha,
+    half const *dA, device_blas_int ldda,
+    half const *dB, device_blas_int lddb,
+    half beta,
+    half       *dC, device_blas_int lddc,
+    blas::Queue& queue )
+{
+    blas_dev_call(
+        cublasHgemm(
+            queue.handle(),
+            op2cublas(transA), op2cublas(transB),
+            m, n, k,
+            &alpha, dA, ldda,
+                    dB, lddb,
+            &beta,  dC, lddc ) );
+}
+
 //------------------------------------------------------------------------------
 // trsm
 //------------------------------------------------------------------------------
diff --git a/src/device_gemm.cc b/src/device_gemm.cc
index a7af0b5a..5e73aa77 100644
--- a/src/device_gemm.cc
+++ b/src/device_gemm.cc
@@ -103,6 +103,25 @@ void gemm(
 //==============================================================================
 // High-level overloaded wrappers call mid-level templated wrapper.
 
+//------------------------------------------------------------------------------
+/// GPU device, half version.
+/// @ingroup gemm
+void gemm(
+    blas::Layout layout,
+    blas::Op transA,
+    blas::Op transB,
+    int64_t m, int64_t n, int64_t k,
+    half alpha,
+    half const* A, int64_t lda,
+    half const* B, int64_t ldb,
+    half beta,
+    half*       C, int64_t ldc,
+    blas::Queue& queue )
+{
+    impl::gemm( layout, transA, transB, m, n, k,
+                alpha, A, lda, B, ldb, beta, C, ldc, queue );
+}
+
 //------------------------------------------------------------------------------
 /// GPU device, float version.
 /// @ingroup gemm
diff --git a/src/device_internal.hh b/src/device_internal.hh
index a70ba304..9426b655 100644
--- a/src/device_internal.hh
+++ b/src/device_internal.hh
@@ -480,6 +480,16 @@ void copy(
 // Level 3 BLAS - Device Interfaces
 
 //------------------------------------------------------------------------------
+void gemm(
+    blas::Op transA, blas::Op transB,
+    device_blas_int m, device_blas_int n, device_blas_int k,
+    half alpha,
+    half const *dA, device_blas_int ldda,
+    half const *dB, device_blas_int lddb,
+    half beta,
+    half       *dC, device_blas_int lddc,
+    blas::Queue& queue );
+
 void gemm(
     blas::Op transA, blas::Op transB,
     device_blas_int m, device_blas_int n, device_blas_int k,

From 60afc47cd6ffaaa1003b2c95d1b0d8854c500dce Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Fri, 22 Sep 2023 15:49:31 +0000
Subject: [PATCH 02/25] hgemm: add hgemm tester; conversion routine from and to
 half; Nvidia support for now.

---
 GNUmakefile              |  10 +-
 test/check_gemm.hh       |  14 ++-
 test/test.cc             |   2 +-
 test/test_gemm_device.cc | 216 +++++++++++++++++++++++++++++++++++++++
 test/utils.cu            | 155 ++++++++++++++++++++++++++++
 test/utils.cuh           |  19 ++++
 6 files changed, 412 insertions(+), 4 deletions(-)
 create mode 100644 test/utils.cu
 create mode 100644 test/utils.cuh

diff --git a/GNUmakefile b/GNUmakefile
index 8a12fdd7..d3767520 100644
--- a/GNUmakefile
+++ b/GNUmakefile
@@ -37,6 +37,10 @@ make.inc:
 RANLIB   ?= ranlib
 prefix   ?= /opt/slate
 
+NVCC     ?= nvcc
+
+NVCCFLAGS  += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function'
+
 abs_prefix := ${abspath ${prefix}}
 
 # Default LD=ld won't work; use CXX. Can override in make.inc or environment.
@@ -77,7 +81,7 @@ lib_src  = $(wildcard src/*.cc)
 lib_obj  = $(addsuffix .o, $(basename $(lib_src)))
 dep     += $(addsuffix .d, $(basename $(lib_src)))
 
-tester_src = $(wildcard test/*.cc)
+tester_src = $(wildcard test/*.cc test/*.cu)
 tester_obj = $(addsuffix .o, $(basename $(tester_src)))
 dep       += $(addsuffix .d, $(basename $(tester_src)))
 
@@ -123,6 +127,7 @@ src/version.o: .id
 #-------------------------------------------------------------------------------
 # BLAS++ specific flags and libraries
 CXXFLAGS += -I./include
+NVCCFLAGS += -I./include
 
 # additional flags and libraries for testers
 $(tester_obj): CXXFLAGS += -I$(testsweeper_dir)
@@ -289,6 +294,9 @@ hooks: ${hooks}
 %.o: %.cc
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+%.o: %.cu
+	$(NVCC) $(NVCCFLAGS) -c $< -o $@
+
 # preprocess source
 %.i: %.cc
 	$(CXX) $(CXXFLAGS) -I$(testsweeper_dir) -E $< -o $@
diff --git a/test/check_gemm.hh b/test/check_gemm.hh
index 385a25a0..26a15a98 100644
--- a/test/check_gemm.hh
+++ b/test/check_gemm.hh
@@ -13,11 +13,21 @@
 
 #include <limits>
 
+namespace std {
+  template <> class numeric_limits<half> {
+  public:
+    static half epsilon() {
+      // Value coming from MAGMA testing/testing_hgemm.cpp
+      return half( 0.00097656 );
+    }
+  };
+}; // namespace std
+
 // -----------------------------------------------------------------------------
 // Computes error for multiplication with general matrix result.
 // Covers dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm?.
 // Cnorm is norm of original C, before multiplication operation.
-template <typename T>
+template <typename T, typename err_prec = T>
 void check_gemm(
     int64_t m, int64_t n, int64_t k,
     T alpha,
@@ -70,7 +80,7 @@ void check_gemm(
         error[0] /= 2*sqrt(2);
     }
 
-    real_t u = 0.5 * std::numeric_limits< real_t >::epsilon();
+    real_t u = 0.5 * std::numeric_limits<blas::real_type<err_prec>>::epsilon();
     *okay = (error[0] < u);
 
     #undef C
diff --git a/test/test.cc b/test/test.cc
index 55dcbb03..8518f77d 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -212,7 +212,7 @@ Params::Params():
 
     // ----- routine parameters
     //          name,      w,    type,            def,                    char2enum,         enum2char,         enum2str,         help
-    datatype  ( "type",    4,    ParamType::List, DataType::Double,       char2datatype,     datatype2char,     datatype2str,     "s=single (float), d=double, c=complex-single, z=complex-double" ),
+    datatype  ( "type",    4,    ParamType::List, DataType::Double,       char2datatype,     datatype2char,     datatype2str,     "h=half, s=single (float), d=double, c=complex-single, z=complex-double" ),
     layout    ( "layout",  6,    ParamType::List, blas::Layout::ColMajor, blas::char2layout, blas::layout2char, blas::layout2str, "layout: r=row major, c=column major" ),
     format    ( "format",  6,    ParamType::List, blas::Format::LAPACK,   blas::char2format, blas::format2char, blas::format2str, "format: l=lapack, t=tile" ),
     side      ( "side",    6,    ParamType::List, blas::Side::Left,       blas::char2side,   blas::side2char,   blas::side2str,   "side: l=left, r=right" ),
diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc
index 895cddb2..66a2b62f 100644
--- a/test/test_gemm_device.cc
+++ b/test/test_gemm_device.cc
@@ -10,6 +10,8 @@
 #include "print_matrix.hh"
 #include "check_gemm.hh"
 
+#include "utils.cuh"
+
 // -----------------------------------------------------------------------------
 template <typename TA, typename TB, typename TC>
 void test_gemm_device_work( Params& params, bool run )
@@ -197,11 +199,225 @@ void test_gemm_device_work( Params& params, bool run )
     blas::device_free( dB, queue );
     blas::device_free( dC, queue );
 }
+//
+// -----------------------------------------------------------------------------
+template <>
+void test_gemm_device_work<half,half,half>( Params& params, bool run )
+{
+    using namespace testsweeper;
+    using std::real;
+    using std::imag;
+    using blas::Op;
+    using blas::Layout;
+    using scalar_hi = float;
+    using scalar_lo = half;
+    using real_t   = blas::real_type< scalar_hi >;
+
+    // get & mark input values
+    blas::Layout layout = params.layout();
+    blas::Op transA     = params.transA();
+    blas::Op transB     = params.transB();
+    scalar_lo alpha     = params.alpha();
+    scalar_lo beta      = params.beta();
+    int64_t m           = params.dim.m();
+    int64_t n           = params.dim.n();
+    int64_t k           = params.dim.k();
+    int64_t device      = params.device();
+    int64_t align       = params.align();
+    int64_t verbose     = params.verbose();
+
+    // mark non-standard output values
+    params.gflops();
+    params.ref_time();
+    params.ref_gflops();
+
+    if (! run)
+        return;
+
+    if (blas::get_device_count() == 0) {
+        params.msg() = "skipping: no GPU devices or no GPU support";
+        return;
+    }
+
+    // setup
+    int64_t Am = (transA == Op::NoTrans ? m : k);
+    int64_t An = (transA == Op::NoTrans ? k : m);
+    int64_t Bm = (transB == Op::NoTrans ? k : n);
+    int64_t Bn = (transB == Op::NoTrans ? n : k);
+    int64_t Cm = m;
+    int64_t Cn = n;
+    if (layout == Layout::RowMajor) {
+        std::swap( Am, An );
+        std::swap( Bm, Bn );
+        std::swap( Cm, Cn );
+    }
+    int64_t lda = roundup( Am, align );
+    int64_t ldb = roundup( Bm, align );
+    int64_t ldc = roundup( Cm, align );
+    size_t size_A = size_t(lda)*An;
+    size_t size_B = size_t(ldb)*Bn;
+    size_t size_C = size_t(ldc)*Cn;
+    half* A_lo   = new half[ size_A ];
+    half* B_lo   = new half[ size_B ];
+    half* C_lo   = new half[ size_C ];
+    float* A_hi  = new float[ size_A ];
+    float* B_hi  = new float[ size_B ];
+    float* C_hi  = new float[ size_C ];
+    float* Cref  = new float[ size_C ];
+
+    // device specifics
+    blas::Queue queue( device );
+    half* dA_lo;
+    half* dB_lo;
+    half* dC_lo;
+    float* dA_hi;
+    float* dB_hi;
+    float* dC_hi;
+
+    dA_lo = blas::device_malloc<half>( size_A, queue );
+    dB_lo = blas::device_malloc<half>( size_B, queue );
+    dC_lo = blas::device_malloc<half>( size_C, queue );
+    dA_hi = blas::device_malloc<float>( size_A, queue );
+    dB_hi = blas::device_malloc<float>( size_B, queue );
+    dC_hi = blas::device_malloc<float>( size_C, queue );
+
+    int64_t idist = 1;
+    int iseed[4] = { 0, 0, 0, 1 };
+    lapack_larnv( idist, iseed, size_A, A_hi );
+    lapack_larnv( idist, iseed, size_B, B_hi );
+    lapack_larnv( idist, iseed, size_C, C_hi );
+    lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc );
+
+    blas::device_copy_matrix(Am, An, A_hi, lda, dA_hi, lda, queue);
+    blas::device_copy_matrix(Bm, Bn, B_hi, ldb, dB_hi, ldb, queue);
+    blas::device_copy_matrix(Cm, Cn, C_hi, ldc, dC_hi, ldc, queue);
+
+    blas::copy_matrix( Am, An, dA_hi, lda, dA_lo, lda, queue );
+    blas::copy_matrix( Bm, Bn, dB_hi, ldb, dB_lo, ldb, queue );
+    blas::copy_matrix( Cm, Cn, dC_hi, ldc, dC_lo, ldc, queue );
+    queue.sync();
+
+    // norms for error check
+    real_t work[1];
+    real_t Anorm = lapack_lange( "f", Am, An, A_hi, lda, work );
+    real_t Bnorm = lapack_lange( "f", Bm, Bn, B_hi, ldb, work );
+    real_t Cnorm = lapack_lange( "f", Cm, Cn, C_hi, ldc, work );
+
+    // test error exits
+    assert_throw( blas::gemm( Layout(0), transA, transB,  m,  n,  k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( layout,    Op(0),  transB,  m,  n,  k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, Op(0),   m,  n,  k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, transB, -1,  n,  k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, transB,  m, -1,  k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, transB,  m,  n, -1, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans,   Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::Trans,     Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans,   Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::Trans,     Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, m-1, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, n-1, queue ), blas::Error );
+
+    if (verbose >= 1) {
+        printf( "\n"
+                "A Am=%5lld, An=%5lld, lda=%5lld, size=%10lld, norm %.2e\n"
+                "B Bm=%5lld, Bn=%5lld, ldb=%5lld, size=%10lld, norm %.2e\n"
+                "C Cm=%5lld, Cn=%5lld, ldc=%5lld, size=%10lld, norm %.2e\n",
+                llong( Am ), llong( An ), llong( lda ), llong( size_A ), Anorm,
+                llong( Bm ), llong( Bn ), llong( ldb ), llong( size_B ), Bnorm,
+                llong( Cm ), llong( Cn ), llong( ldc ), llong( size_C ), Cnorm );
+    }
+    if (verbose >= 2) {
+        printf( "alpha = %.4e + %.4ei; beta = %.4e + %.4ei;\n",
+                blas::real(alpha), blas::imag(alpha),
+                blas::real(beta),  blas::imag(beta) );
+        printf( "A = "    ); print_matrix( Am, An, A_hi, lda );
+        printf( "B = "    ); print_matrix( Bm, Bn, B_hi, ldb );
+        printf( "C = "    ); print_matrix( Cm, Cn, C_hi, ldc );
+    }
+
+    // run test
+    testsweeper::flush_cache( params.cache() );
+    double time = get_wtime();
+    blas::gemm( layout, transA, transB, m, n, k,
+                alpha, dA_lo, lda, dB_lo, ldb, beta, dC_lo, ldc, queue );
+    queue.sync();
+    time = get_wtime() - time;
+
+    double gflop = blas::Gflop< scalar_hi >::gemm( m, n, k );
+    params.time()   = time;
+    params.gflops() = gflop / time;
+
+    blas::copy_matrix( Cm, Cn, dC_lo, ldc, dC_hi, ldc, queue );
+    blas::device_copy_matrix(Cm, Cn, dC_hi, ldc, C_hi, ldc, queue);
+    queue.sync();
+
+    if (verbose >= 2) {
+        printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc );
+    }
+
+    if (params.ref() == 'y' || params.check() == 'y') {
+        // run reference
+        testsweeper::flush_cache( params.cache() );
+        time = get_wtime();
+        cblas_gemm( cblas_layout_const(layout),
+                    cblas_trans_const(transA),
+                    cblas_trans_const(transB),
+                    m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, Cref, ldc ); // keep it like this as it defines the reference
+        time = get_wtime() - time;
+
+        params.ref_time()   = time;
+        params.ref_gflops() = gflop / time;
+
+        if (verbose >= 2) {
+            printf( "Cref = " ); print_matrix( Cm, Cn, Cref, ldc );
+        }
+
+        // check error compared to reference
+        real_t error;
+        bool okay;
+        check_gemm<float, half>( Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
+                    Cref, ldc, C_hi, ldc, verbose, &error, &okay );
+        params.error() = error;
+        params.okay() = okay;
+    }
+
+    delete[] A_hi;
+    delete[] B_hi;
+    delete[] C_hi;
+    delete[] A_lo;
+    delete[] B_lo;
+    delete[] C_lo;
+    delete[] Cref;
+
+    blas::device_free( dA_hi, queue );
+    blas::device_free( dB_hi, queue );
+    blas::device_free( dC_hi, queue );
+    blas::device_free( dA_lo, queue );
+    blas::device_free( dB_lo, queue );
+    blas::device_free( dC_lo, queue );
+}
 
 // -----------------------------------------------------------------------------
 void test_gemm_device( Params& params, bool run )
 {
     switch (params.datatype()) {
+        case testsweeper::DataType::Half:
+            test_gemm_device_work< half, half, half >( params, run );
+            break;
+
         case testsweeper::DataType::Single:
             test_gemm_device_work< float, float, float >( params, run );
             break;
diff --git a/test/utils.cu b/test/utils.cu
new file mode 100644
index 00000000..b9e16f46
--- /dev/null
+++ b/test/utils.cu
@@ -0,0 +1,155 @@
+#include "utils.cuh"
+
+//------------------------------------------------------------------------------
+/// @return ceil( x / y ), for integer type T.
+template <typename T>
+inline constexpr T ceildiv( T x, T y )
+{
+    return T( (x + y - 1) / y );
+}
+
+//==============================================================================
+// Overloads to enable templating.
+
+//------------------------------------------------------------------------------
+// Template implementation when C++ has default rules for conversion
+// (or no conversion needed).
+template <typename src_t, typename dst_t>
+__host__ __device__
+inline void copy_scalar( src_t src, dst_t& dst )
+{
+    dst = dst_t( src );
+}
+
+//------------------------------------------------------------------------------
+// Overloaded implementations for specific cases.
+__host__ __device__
+inline void copy_scalar( float src, __half& dst )
+{
+    dst = __float2half( src );
+}
+
+//------------------------------------------------------------------------------
+__host__ __device__
+inline void copy_scalar( double src, __half& dst )
+{
+    dst = __double2half( src );
+}
+
+//------------------------------------------------------------------------------
+__host__ __device__
+inline void copy_scalar( __half src, float& dst )
+{
+    dst = __half2float( src );
+}
+
+//------------------------------------------------------------------------------
+__host__ __device__
+inline void copy_scalar( __half src, double& dst )
+{
+    // no __half2double, so do in 2 steps
+    dst = double( __half2float( src ) );
+}
+
+//==============================================================================
+// GPU function.
+const int blk_x = 32;
+const int blk_y = 4;
+
+//------------------------------------------------------------------------------
+// GPU device routine, called from GPU kernel.
+//
+// Each thread-block does a blk_x by blk_y block of the matrix;
+// each thread does 1 entry in the block.
+// Because of the CUDA max y grid dimension of 65535, the entire grid is
+// repeated in the y dimension with step = gridDim.y * blockDim.y,
+// so thread (i, j) will do entries (i, j), (i, j + step), (i, j + 2*step), ...
+// The max x grid dimension is 2^31-1, so there's no need to repeat.
+// Cf. magma/magmablas/slag2h.cu
+// Element offsets are computed in size_t so that j*ld cannot overflow int.
+template <typename src_t, typename dst_t>
+__device__
+void copy_matrix_device(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst )
+{
+    // Global thread index.
+    const int ti = blockIdx.x * blockDim.x + threadIdx.x;
+    const int tj = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (ti < m) {
+        for (int j = tj; j < n; j += gridDim.y * blockDim.y) {
+            copy_scalar( src[ ti + size_t( j )*ld_src ],
+                         dst[ ti + size_t( j )*ld_dst ] );
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// GPU kernel, called from CPU driver.
+template <typename src_t, typename dst_t>
+__global__
+void copy_matrix_kernel(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst )
+{
+    copy_matrix_device( m, n, src, ld_src, dst, ld_dst );
+}
+
+namespace blas {
+
+//------------------------------------------------------------------------------
+// Copy m-by-n src matrix to dst matrix, with type conversion.
+template <typename src_t, typename dst_t>
+void copy_matrix(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst,
+    blas::Queue &queue)
+{
+    // Quick return: an empty matrix gives a zero-sized grid, an invalid launch.
+    if (m == 0 || n == 0)
+        return;
+
+    cudaStream_t stream = queue.stream();
+
+    // Cap y grid at 4 blocks (CUDA max is 65535); the kernel strides over columns.
+    dim3 threads( blk_x, blk_y );
+    dim3 blocks( ceildiv( m, blk_x ), std::min( 4, ceildiv( n, blk_y ) ) );
+    // printf( "%s: m %d, n %d; threads %d, %d, %d; blocks %d, %d, %d\n",
+    //         __func__, m, n,
+    //         threads.x, threads.y, threads.z,
+    //         blocks.x,  blocks.y,  blocks.z );
+    copy_matrix_kernel<<< blocks, threads, 0, stream >>>
+        ( m, n, src, ld_src, dst, ld_dst );
+
+    // Check that launch succeeded. This doesn't say execution will be
+    // successful, it checks only the launch.
+    blas_dev_call( cudaGetLastError() );
+}
+
+//------------------------------------------------------------------------------
+// Explicit instantiations.
+template
+void copy_matrix<float, half>(
+    int m, int n,
+    float const* src, int ld_src,
+    half*        dst, int ld_dst,
+    blas::Queue &queue);
+
+template
+void copy_matrix<half, float>(
+    int m, int n,
+    half const* src, int ld_src,
+    float*      dst, int ld_dst,
+    blas::Queue &queue);
+
+template
+void copy_matrix<float, float>(
+    int m, int n,
+    float const* src, int ld_src,
+    float*       dst, int ld_dst,
+    blas::Queue &queue);
+} // namespace blas
diff --git a/test/utils.cuh b/test/utils.cuh
new file mode 100644
index 00000000..e698ac67
--- /dev/null
+++ b/test/utils.cuh
@@ -0,0 +1,19 @@
+#ifndef UTILS_CUH
+#define UTILS_CUH
+
+#include "blas.hh"
+
+#include <cuda_runtime.h>
+
+namespace blas {
+
+template <typename src_t, typename dst_t = src_t>
+void copy_matrix(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst,
+    blas::Queue &queue);
+
+} // namespace blas
+
+#endif // UTILS_CUH

From f51f0c2ca93e4275bcf8dd8cb4832e1154dca7b7 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Mon, 25 Sep 2023 14:35:38 +0000
Subject: [PATCH 03/25] utils: move utils into test/cuda; Add missing NVCC
 flag; Clean utils header.

---
 test/utils.cuh | 2 --
 1 file changed, 2 deletions(-)

diff --git a/test/utils.cuh b/test/utils.cuh
index e698ac67..3c481bf5 100644
--- a/test/utils.cuh
+++ b/test/utils.cuh
@@ -3,8 +3,6 @@
 
 #include "blas.hh"
 
-#include <cuda_runtime.h>
-
 namespace blas {
 
 template <typename src_t, typename dst_t = src_t>

From 98e452b48aa47d9efce20b33baae39ccb18a6008 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Mon, 25 Sep 2023 14:36:03 +0000
Subject: [PATCH 04/25] utils: move utils into test/cuda; Add missing NVCC
 flag; Clean utils header.

---
 GNUmakefile              | 1 +
 test/{ => cuda}/utils.cu | 0
 2 files changed, 1 insertion(+)
 rename test/{ => cuda}/utils.cu (100%)

diff --git a/GNUmakefile b/GNUmakefile
index d3767520..8c8af2d5 100644
--- a/GNUmakefile
+++ b/GNUmakefile
@@ -61,6 +61,7 @@ endif
 ifneq ($(static),1)
     CXXFLAGS += -fPIC
     LDFLAGS  += -fPIC
+    NVCCFLAGS  += --compiler-options '-fPIC'
     lib_ext = so
 else
     lib_ext = a
diff --git a/test/utils.cu b/test/cuda/utils.cu
similarity index 100%
rename from test/utils.cu
rename to test/cuda/utils.cu

From 64e669778315f50e473761d450a5f270a650ca7c Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 27 Sep 2023 20:08:37 +0000
Subject: [PATCH 05/25] half: Add hipify from the cuda utils; update the
 compilation chain to detect hip.

---
 GNUmakefile        | 107 ++++++++++++++++++++++++++++++++++++++++++++-
 config/config.py   |   3 ++
 make.inc.in        |  14 +++---
 test/cuda/utils.cu |   2 +-
 4 files changed, 118 insertions(+), 8 deletions(-)

diff --git a/GNUmakefile b/GNUmakefile
index 8c8af2d5..44f152b4 100644
--- a/GNUmakefile
+++ b/GNUmakefile
@@ -38,8 +38,12 @@ RANLIB   ?= ranlib
 prefix   ?= /opt/slate
 
 NVCC     ?= nvcc
+HIPCC    ?= hipcc
+hipify   ?= hipify-perl
+md5sum   ?= tools/md5sum.pl
 
 NVCCFLAGS  += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function'
+HIPCCFLAGS += -std=c++11 -DTCE_HIP -fno-gpu-rdc
 
 abs_prefix := ${abspath ${prefix}}
 
@@ -56,12 +60,25 @@ ifneq ($(findstring darwin, $(ostype)),)
     macos = 1
 endif
 
+#-------------------------------------------------------------------------------
+# Detect which GPU backend is used; gpu_backend is set in make.inc.
+cuda = 0
+hip  = 0
+sycl = 0
+
+ifeq ($(gpu_backend),cuda)
+    cuda = 1
+else ifeq ($(gpu_backend),hip)
+    hip = 1
+endif
+
 #-------------------------------------------------------------------------------
 # if shared
 ifneq ($(static),1)
     CXXFLAGS += -fPIC
     LDFLAGS  += -fPIC
     NVCCFLAGS  += --compiler-options '-fPIC'
+    HIPCCFLAGS += -fPIC
     lib_ext = so
 else
     lib_ext = a
@@ -82,7 +99,19 @@ lib_src  = $(wildcard src/*.cc)
 lib_obj  = $(addsuffix .o, $(basename $(lib_src)))
 dep     += $(addsuffix .d, $(basename $(lib_src)))
 
-tester_src = $(wildcard test/*.cc test/*.cu)
+cuda_src = $(wildcard test/cuda/*.cu)
+hip_src  = $(patsubst test/cuda/%.cu,test/hip/%.hip.cc,$(cuda_src))
+
+tester_src = $(wildcard test/*.cc)
+
+ifeq ($(cuda),1)
+    tester_src += $(cuda_src)
+endif
+
+ifeq ($(hip),1)
+    tester_src += $(hip_src)
+endif
+
 tester_obj = $(addsuffix .o, $(basename $(tester_src)))
 dep       += $(addsuffix .d, $(basename $(tester_src)))
 
@@ -129,6 +158,7 @@ src/version.o: .id
 # BLAS++ specific flags and libraries
 CXXFLAGS += -I./include
 NVCCFLAGS += -I./include
+HIPCCFLAGS += -I./include
 
 # additional flags and libraries for testers
 $(tester_obj): CXXFLAGS += -I$(testsweeper_dir)
@@ -164,6 +194,59 @@ uninstall:
 	$(RM) $(DESTDIR)$(abs_prefix)/lib$(LIB_SUFFIX)/libblaspp.*
 	$(RM) $(DESTDIR)$(abs_prefix)/lib$(LIB_SUFFIX)/pkgconfig/blaspp.pc
 
+#-------------------------------------------------------------------------------
+# HIP sources converted from CUDA sources.
+
+# if_md5_outdated applies the given build rule ($1) only if the md5 sums
+# of the target's dependency ($<) doesn't match that stored in the
+# target's dep file ($@.dep). If the target ($@) is already up-to-date
+# based on md5 sums, its timestamp is updated so make will recognize it
+# as up-to-date. Otherwise, the target is built and its dep file
+# updated. Instead of depending on the src file, the target depends on
+# the md5 file of the src file. This can be adapted for multiple dependencies.
+# Example usage:
+#
+# %: %.c.md5
+#     ${call if_md5_outdated,\
+#            gcc -o $@ ${basename $<}}
+#
+define if_md5_outdated
+    if [ -e $@ ] && diff $< $@.dep > /dev/null 2>&1; then \
+        echo "  make: '$@' is up-to-date based on md5sum."; \
+        echo "  touch $@"; \
+                touch $@; \
+    else \
+        echo "  make: '$@' is out-of-date based on md5sum."; \
+        echo "  ${strip $1}"; \
+        $1; \
+        cp $< $@.dep; \
+    fi
+endef
+
+# From GNU manual: Commas ... cannot appear in an argument as written.
+# The[y] can be put into the argument value by variable substitution.
+comma := ,
+
+# Convert CUDA => HIP code.
+# Explicitly mention ${hip_src}, ${hip_hdr}, ${md5_files}
+# to prevent them from being intermediate files,
+# so they are _always_ generated and never removed.
+# Perl updates includes and removes excess spaces that fail style hook.
+${hip_src}: test/hip/%.hip.cc: test/cuda/%.cu.md5 | test/hip
+	@${call if_md5_outdated, \
+	        ${hipify} ${basename $<} > $@; \
+	        perl -pi -e 's/\.cuh/.hip.hh/g; s/ +(${comma}|;|$$)/$$1/g;' $@}
+
+hipify: ${hip_src}
+
+md5_files := ${addsuffix .md5, ${cuda_src}}
+
+${md5_files}: %.md5: %
+	${md5sum} $< > $@
+
+test/hip:
+	mkdir -p $@
+
 #-------------------------------------------------------------------------------
 # if re-configured, recompile everything
 $(lib_obj) $(tester_obj): make.inc
@@ -292,6 +375,10 @@ hooks: ${hooks}
 		cp $< $@ ; \
 	fi
 
+# .hip.cc rule before .cc rule.
+%.hip.o: %.hip.cc
+	$(HIPCC) $(HIPCCFLAGS) -c $< -o $@
+
 %.o: %.cc
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
@@ -342,6 +429,24 @@ echo:
 	@echo
 	@echo "dep           = $(dep)"
 	@echo
+	@echo "---------- CUDA options"
+	@echo "cuda          = '$(cuda)'"
+	@echo "NVCC          = $(NVCC)"
+	@echo "NVCC_which    = $(NVCC_which)"
+	@echo "CUDA_PATH     = $(CUDA_PATH)"
+	@echo "NVCCFLAGS     = $(NVCCFLAGS)"
+	@echo
+	@echo "---------- HIP options"
+	@echo "hip           = '$(hip)'"
+	@echo "HIPCC         = $(HIPCC)"
+	@echo "HIPCC_which   = $(HIPCC_which)"
+	@echo "ROCM_PATH     = $(ROCM_PATH)"
+	@echo "HIPCCFLAGS    = $(HIPCCFLAGS)"
+	@echo "hipify        = ${hipify}"
+	@echo "cuda_src      = ${cuda_src}"
+	@echo "hip_src       = ${hip_src}"
+	@echo "md5_files     = $(md5_files)"
+	@echo
 	@echo "testsweeper_dir   = $(testsweeper_dir)"
 	@echo "testsweeper_src   = $(testsweeper_src)"
 	@echo "testsweeper       = $(testsweeper)"
diff --git a/config/config.py b/config/config.py
index 71d923d9..8de06489 100644
--- a/config/config.py
+++ b/config/config.py
@@ -752,6 +752,7 @@ def gpu_blas():
         try:
             cublas_library()
             gpu_blas_found = True
+            environ.merge( {'gpu_backend' : 'cuda' } )
         except Error as ex:
             if (gpu_backend == 'cuda'):
                 raise ex  # fatal
@@ -763,6 +764,7 @@ def gpu_blas():
         try:
             rocblas_library()
             gpu_blas_found = True
+            environ.merge( {'gpu_backend' : 'hip' } )
         except Error as ex:
             if (gpu_backend in ('hip', 'rocm')):
                 raise ex  # fatal
@@ -773,6 +775,7 @@ def gpu_blas():
     if (not gpu_blas_found and test_sycl):
         try:
             sycl_onemkl_library()
+            environ.merge( {'gpu_backend' : 'sycl' } )
             gpu_blas_found = True
         except Error as ex:
             if (gpu_backend == 'sycl'):
diff --git a/make.inc.in b/make.inc.in
index fae87ec5..a73d10fa 100644
--- a/make.inc.in
+++ b/make.inc.in
@@ -5,17 +5,19 @@
 # CPATH: @CPATH@
 # LIBRARY_PATH: @LIBRARY_PATH@
 #
-CXX      = @CXX@
+CXX         = @CXX@
 
-CXXFLAGS = @CXXFLAGS@
+CXXFLAGS    = @CXXFLAGS@
 
 # see include/blas/defines.h
 # @DEFINES@
 
-LDFLAGS  = @LDFLAGS@
+LDFLAGS     = @LDFLAGS@
 
-LIBS     = @LIBS@
+LIBS        = @LIBS@
 
-prefix   = @prefix@
+prefix      = @prefix@
 
-static   = @static@
+static      = @static@
+
+gpu_backend = @gpu_backend@
diff --git a/test/cuda/utils.cu b/test/cuda/utils.cu
index b9e16f46..588b1f28 100644
--- a/test/cuda/utils.cu
+++ b/test/cuda/utils.cu
@@ -1,4 +1,4 @@
-#include "utils.cuh"
+#include "../utils.cuh"
 
 //------------------------------------------------------------------------------
 /// @return ceil( x / y ), for integer type T.

From c7061d03793f1ed7bc63beb0eaa1c46d9a643b03 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 4 Oct 2023 17:47:56 -0400
Subject: [PATCH 06/25] Replace the definition of half from __half to _float16;
 rename blas::half into blas::float16

---
 include/blas/device_blas.hh | 24 ++++++++++----------
 include/blas/util.hh        |  8 +++----
 src/cublas_wrappers.cc      | 45 ++++++++++++++++++++-----------------
 src/device_gemm.cc          | 12 +++++-----
 src/device_internal.hh      | 10 ++++-----
 test/check_gemm.hh          |  8 +++----
 test/cuda/utils.cu          | 40 ++++++++++++++++++++++++---------
 test/test_gemm_device.cc    | 40 +++++++++++++++++----------------
 8 files changed, 104 insertions(+), 83 deletions(-)

diff --git a/include/blas/device_blas.hh b/include/blas/device_blas.hh
index c268cc74..a3530075 100644
--- a/include/blas/device_blas.hh
+++ b/include/blas/device_blas.hh
@@ -207,6 +207,18 @@ void swap(
 // Level 3 BLAS
 
 //------------------------------------------------------------------------------
+void gemm(
+    blas::Layout layout,
+    blas::Op transA,
+    blas::Op transB,
+    int64_t m, int64_t n, int64_t k,
+    float16 alpha,
+    float16 const* A, int64_t lda,
+    float16 const* B, int64_t ldb,
+    float16 beta,
+    float16*       C, int64_t ldc,
+    blas::Queue& queue );
+
 void gemm(
     blas::Layout layout,
     blas::Op transA,
@@ -255,18 +267,6 @@ void gemm(
     std::complex<double>*       C, int64_t ldc,
     blas::Queue& queue );
 
-void gemm(
-    blas::Layout layout,
-    blas::Op transA,
-    blas::Op transB,
-    int64_t m, int64_t n, int64_t k,
-    half alpha,
-    half const* A, int64_t lda,
-    half const* B, int64_t ldb,
-    half beta,
-    half*       C, int64_t ldc,
-    blas::Queue& queue );
-
 //------------------------------------------------------------------------------
 void hemm(
     blas::Layout layout,
diff --git a/include/blas/util.hh b/include/blas/util.hh
index 5df085c0..c88cfe2d 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -14,14 +14,12 @@
 
 #include <assert.h>
 
-#include <cuda_fp16.h>
-
 namespace blas {
 
-using half = __half;
+using float16 = _Float16;
 
-inline float real(half& a) { return a; }
-inline float imag(half& a) { return 0.0; }
+inline float real(float16& a) { return float( a ); }
+inline float imag(float16& a) { return 0.0f; }
 
 /// Use to silence compiler warning of unused variable.
 #define blas_unused( var ) ((void)var)
diff --git a/src/cublas_wrappers.cc b/src/cublas_wrappers.cc
index d98eae99..260ba763 100644
--- a/src/cublas_wrappers.cc
+++ b/src/cublas_wrappers.cc
@@ -478,6 +478,30 @@ void copy(
 
 //------------------------------------------------------------------------------
 // gemm
+//------------------------------------------------------------------------------
+void gemm(
+    blas::Op transA, blas::Op transB,
+    device_blas_int m, device_blas_int n, device_blas_int k,
+    float16 alpha,
+    float16 const *dA, device_blas_int ldda,
+    float16 const *dB, device_blas_int lddb,
+    float16 beta,
+    float16       *dC, device_blas_int lddc,
+    blas::Queue& queue )
+{
+    __half alpha_ = __half( alpha );
+    __half beta_  = __half( beta );
+
+    blas_dev_call(
+        cublasHgemm(
+            queue.handle(),
+            op2cublas(transA), op2cublas(transB),
+            m, n, k,
+            &alpha_, (__half const*)dA, ldda,
+                     (__half const*)dB, lddb,
+            &beta_,  (__half*)dC, lddc ) );
+}
+
 //------------------------------------------------------------------------------
 void gemm(
     blas::Op transA, blas::Op transB,
@@ -566,27 +590,6 @@ void gemm(
             (cuDoubleComplex*) dC, lddc ) );
 }
 
-//------------------------------------------------------------------------------
-void gemm(
-    blas::Op transA, blas::Op transB,
-    device_blas_int m, device_blas_int n, device_blas_int k,
-    half alpha,
-    half const *dA, device_blas_int ldda,
-    half const *dB, device_blas_int lddb,
-    half beta,
-    half       *dC, device_blas_int lddc,
-    blas::Queue& queue )
-{
-    blas_dev_call(
-        cublasHgemm(
-            queue.handle(),
-            op2cublas(transA), op2cublas(transB),
-            m, n, k,
-            &alpha, dA, ldda,
-                    dB, lddb,
-            &beta,  dC, lddc ) );
-}
-
 //------------------------------------------------------------------------------
 // trsm
 //------------------------------------------------------------------------------
diff --git a/src/device_gemm.cc b/src/device_gemm.cc
index 5e73aa77..f330aed7 100644
--- a/src/device_gemm.cc
+++ b/src/device_gemm.cc
@@ -104,18 +104,18 @@ void gemm(
 // High-level overloaded wrappers call mid-level templated wrapper.
 
 //------------------------------------------------------------------------------
-/// GPU device, half version.
+/// GPU device, float16 version.
 /// @ingroup gemm
 void gemm(
     blas::Layout layout,
     blas::Op transA,
     blas::Op transB,
     int64_t m, int64_t n, int64_t k,
-    half alpha,
-    half const* A, int64_t lda,
-    half const* B, int64_t ldb,
-    half beta,
-    half*       C, int64_t ldc,
+    float16 alpha,
+    float16 const* A, int64_t lda,
+    float16 const* B, int64_t ldb,
+    float16 beta,
+    float16*       C, int64_t ldc,
     blas::Queue& queue )
 {
     impl::gemm( layout, transA, transB, m, n, k,
diff --git a/src/device_internal.hh b/src/device_internal.hh
index 9426b655..ddd5e6d0 100644
--- a/src/device_internal.hh
+++ b/src/device_internal.hh
@@ -483,11 +483,11 @@ void copy(
 void gemm(
     blas::Op transA, blas::Op transB,
     device_blas_int m, device_blas_int n, device_blas_int k,
-    half alpha,
-    half const *dA, device_blas_int ldda,
-    half const *dB, device_blas_int lddb,
-    half beta,
-    half       *dC, device_blas_int lddc,
+    float16 alpha,
+    float16 const *dA, device_blas_int ldda,
+    float16 const *dB, device_blas_int lddb,
+    float16 beta,
+    float16       *dC, device_blas_int lddc,
     blas::Queue& queue );
 
 void gemm(
diff --git a/test/check_gemm.hh b/test/check_gemm.hh
index 26a15a98..f7ce94b0 100644
--- a/test/check_gemm.hh
+++ b/test/check_gemm.hh
@@ -14,11 +14,11 @@
 #include <limits>
 
 namespace std {
-  template <> class numeric_limits<half> {
+  template <> class numeric_limits<blas::float16> {
   public:
-    static half epsilon() {
-      // Value coming from MAGMA testing/testing_hgemm.cpp
-      return half( 0.00097656 );
+    static blas::float16 epsilon() {
+        // Value coming from MAGMA testing/testing_hgemm.cpp
+        return blas::float16( 0.00097656 );
     }
   };
 }; // namespace std
diff --git a/test/cuda/utils.cu b/test/cuda/utils.cu
index 588b1f28..01003173 100644
--- a/test/cuda/utils.cu
+++ b/test/cuda/utils.cu
@@ -1,4 +1,4 @@
-#include "../utils.cuh"
+#include "../utils.hh"
 
 //------------------------------------------------------------------------------
 /// @return ceil( x / y ), for integer type T.
@@ -33,7 +33,11 @@ inline void copy_scalar( float src, __half& dst )
 __host__ __device__
 inline void copy_scalar( double src, __half& dst )
 {
+#ifdef __NVCC__
     dst = __double2half( src );
+#else
+    dst = __float2half( (float)src );
+#endif
 }
 
 //------------------------------------------------------------------------------
@@ -132,21 +136,35 @@ void copy_matrix(
 
 //------------------------------------------------------------------------------
 // Explicit instantiations.
-template
-void copy_matrix<float, half>(
+template <>
+void copy_matrix<float, float16>(
     int m, int n,
     float const* src, int ld_src,
-    half*        dst, int ld_dst,
-    blas::Queue &queue);
+    float16*     dst, int ld_dst,
+    blas::Queue &queue)
+{
+    copy_matrix(
+        m, n,
+                  src, ld_src,
+        (__half*) dst, ld_dst,
+        queue );
+}
 
-template
-void copy_matrix<half, float>(
+template <>
+void copy_matrix<float16, float>(
     int m, int n,
-    half const* src, int ld_src,
-    float*      dst, int ld_dst,
-    blas::Queue &queue);
+    float16 const* src, int ld_src,
+    float*         dst, int ld_dst,
+    blas::Queue &queue)
+{
+    copy_matrix(
+        m, n,
+        (__half const*) src, ld_src,  // reinterpret float16 as __half; keep const
+                        dst, ld_dst,
+        queue );
+}
 
-template
+template  // explicit instantiation of the generic copy_matrix, not a specialization
 void copy_matrix<float, float>(
     int m, int n,
     float const* src, int ld_src,
diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc
index 66a2b62f..2344f108 100644
--- a/test/test_gemm_device.cc
+++ b/test/test_gemm_device.cc
@@ -202,7 +202,7 @@ void test_gemm_device_work( Params& params, bool run )
 //
 // -----------------------------------------------------------------------------
 template <>
-void test_gemm_device_work<half,half,half>( Params& params, bool run )
+void test_gemm_device_work<blas::float16,blas::float16,blas::float16>( Params& params, bool run )
 {
     using namespace testsweeper;
     using std::real;
@@ -210,7 +210,7 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
     using blas::Op;
     using blas::Layout;
     using scalar_hi = float;
-    using scalar_lo = half;
+    using scalar_lo = blas::float16;
     using real_t   = blas::real_type< scalar_hi >;
 
     // get & mark input values
@@ -257,9 +257,9 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
     size_t size_A = size_t(lda)*An;
     size_t size_B = size_t(ldb)*Bn;
     size_t size_C = size_t(ldc)*Cn;
-    half* A_lo   = new half[ size_A ];
-    half* B_lo   = new half[ size_B ];
-    half* C_lo   = new half[ size_C ];
+    blas::float16* A_lo   = new blas::float16[ size_A ];
+    blas::float16* B_lo   = new blas::float16[ size_B ];
+    blas::float16* C_lo   = new blas::float16[ size_C ];
     float* A_hi  = new float[ size_A ];
     float* B_hi  = new float[ size_B ];
     float* C_hi  = new float[ size_C ];
@@ -267,16 +267,16 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
 
     // device specifics
     blas::Queue queue( device );
-    half* dA_lo;
-    half* dB_lo;
-    half* dC_lo;
+    blas::float16* dA_lo;
+    blas::float16* dB_lo;
+    blas::float16* dC_lo;
     float* dA_hi;
     float* dB_hi;
     float* dC_hi;
 
-    dA_lo = blas::device_malloc<half>( size_A, queue );
-    dB_lo = blas::device_malloc<half>( size_B, queue );
-    dC_lo = blas::device_malloc<half>( size_C, queue );
+    dA_lo = blas::device_malloc<blas::float16>( size_A, queue );
+    dB_lo = blas::device_malloc<blas::float16>( size_B, queue );
+    dC_lo = blas::device_malloc<blas::float16>( size_C, queue );
     dA_hi = blas::device_malloc<float>( size_A, queue );
     dB_hi = blas::device_malloc<float>( size_B, queue );
     dC_hi = blas::device_malloc<float>( size_C, queue );
@@ -292,6 +292,7 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
     blas::device_copy_matrix(Bm, Bn, B_hi, ldb, dB_hi, ldb, queue);
     blas::device_copy_matrix(Cm, Cn, C_hi, ldc, dC_hi, ldc, queue);
 
+    // Convert float->float16
     blas::copy_matrix( Am, An, dA_hi, lda, dA_lo, lda, queue );
     blas::copy_matrix( Bm, Bn, dB_hi, ldb, dB_lo, ldb, queue );
     blas::copy_matrix( Cm, Cn, dC_hi, ldc, dC_lo, ldc, queue );
@@ -319,13 +320,13 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
     assert_throw( blas::gemm( Layout::RowMajor, Op::Trans,     Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
     assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
 
-    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
-    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
-    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
 
-    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
-    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
-    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, dA_hi, lda, dB_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, dB_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
 
     assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, m-1, queue ), blas::Error );
     assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, n-1, queue ), blas::Error );
@@ -360,6 +361,7 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
     params.time()   = time;
     params.gflops() = gflop / time;
 
+    // Convert float16->float
     blas::copy_matrix( Cm, Cn, dC_lo, ldc, dC_hi, ldc, queue );
     blas::device_copy_matrix(Cm, Cn, dC_hi, ldc, C_hi, ldc, queue);
     queue.sync();
@@ -388,7 +390,7 @@ void test_gemm_device_work<half,half,half>( Params& params, bool run )
         // check error compared to reference
         real_t error;
         bool okay;
-        check_gemm<float, half>( Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
+        check_gemm<float, blas::float16>( Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
                     Cref, ldc, C_hi, ldc, verbose, &error, &okay );
         params.error() = error;
         params.okay() = okay;
@@ -415,7 +417,7 @@ void test_gemm_device( Params& params, bool run )
 {
     switch (params.datatype()) {
         case testsweeper::DataType::Half:
-            test_gemm_device_work< half, half, half >( params, run );
+            test_gemm_device_work< blas::float16, blas::float16, blas::float16 >( params, run );
             break;
 
         case testsweeper::DataType::Single:

From 68cb36de6e1f5d905d975490cfe50a8b19113558 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 4 Oct 2023 17:51:10 -0400
Subject: [PATCH 07/25] Rename test/utils.cuh into test/utils.hh

---
 test/utils.hh | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 test/utils.hh

diff --git a/test/utils.hh b/test/utils.hh
new file mode 100644
index 00000000..1ee0d848
--- /dev/null
+++ b/test/utils.hh
@@ -0,0 +1,17 @@
+#ifndef UTILS_HH
+#define UTILS_HH
+
+#include "blas.hh"
+
+namespace blas {
+
+template <typename src_t, typename dst_t = src_t>
+void copy_matrix(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst,
+    blas::Queue &queue);
+
+} // namespace blas
+
+#endif // UTILS_HH

From 1c99c49ae50cbb0c914a5f97ca6918548918eaac Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 4 Oct 2023 17:53:11 -0400
Subject: [PATCH 08/25] Add hip files that got generated from cuda files.

---
 test/hip/utils.hip.cc     | 174 ++++++++++++++++++++++++++++++++++++++
 test/hip/utils.hip.cc.dep |   1 +
 test/utils.cuh            |  17 ----
 tools/md5sum.pl           |  17 ++++
 4 files changed, 192 insertions(+), 17 deletions(-)
 create mode 100644 test/hip/utils.hip.cc
 create mode 100644 test/hip/utils.hip.cc.dep
 delete mode 100644 test/utils.cuh
 create mode 100755 tools/md5sum.pl

diff --git a/test/hip/utils.hip.cc b/test/hip/utils.hip.cc
new file mode 100644
index 00000000..e42fe4c5
--- /dev/null
+++ b/test/hip/utils.hip.cc
@@ -0,0 +1,174 @@
+#include "hip/hip_runtime.h"
+#include "../utils.hh"
+
+//------------------------------------------------------------------------------
+/// @return ceil( x / y ), for integer type T.
+template <typename T>
+inline constexpr T ceildiv( T x, T y )
+{
+    return T( (x + y - 1) / y );
+}
+
+//==============================================================================
+// Overloads to enable templating.
+
+//------------------------------------------------------------------------------
+// Template implementation when C++ has default rules for conversion
+// (or no conversion needed).
+template <typename src_t, typename dst_t>
+__host__ __device__
+inline void copy_scalar( src_t src, dst_t& dst )
+{
+    dst = dst_t( src );
+}
+
+//------------------------------------------------------------------------------
+// Overloaded implementations for specific cases.
+__host__ __device__
+inline void copy_scalar( float src, __half& dst )
+{
+    dst = __float2half( src );
+}
+
+//------------------------------------------------------------------------------
+__host__ __device__
+inline void copy_scalar( double src, __half& dst )
+{
+#ifdef __NVCC__
+    dst = __double2half( src );
+#else
+    dst = __float2half( (float)src );
+#endif
+}
+
+//------------------------------------------------------------------------------
+__host__ __device__
+inline void copy_scalar( __half src, float& dst )
+{
+    dst = __half2float( src );
+}
+
+//------------------------------------------------------------------------------
+__host__ __device__
+inline void copy_scalar( __half src, double& dst )
+{
+    // no __half2double, so do in 2 steps
+    dst = double( __half2float( src ) );
+}
+
+//==============================================================================
+// GPU function.
+const int blk_x = 32;
+const int blk_y = 4;
+
+//------------------------------------------------------------------------------
+// GPU device routine, called from GPU kernel.
+//
+// Each thread-block does a blk_x by blk_y block of the matrix;
+// each thread does 1 entry in the block.
+// Because of the CUDA max y grid dimension of 65535, the entire grid is
+// repeated in the y dimension with step = gridDim.y * blockDim.y,
+// so thread (i, j) will do entries (i, j), (i, j + step), (i, j + 2*step), ...
+// The max x grid dimension is 2^31-1, so there's no need to repeat.
+// Cf. magma/magmablas/slag2h.cu
+//
+template <typename src_t, typename dst_t>
+__device__
+void copy_matrix_device(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst )
+{
+    // Global thread index.
+    const int ti = blockIdx.x * blockDim.x + threadIdx.x;
+    const int tj = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (ti < m) {
+        for (int j = tj; j < n; j += gridDim.y * blockDim.y) {
+            copy_scalar( src[ ti + j*ld_src ],
+                         dst[ ti + j*ld_dst ] );
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// GPU kernel, called from CPU driver.
+template <typename src_t, typename dst_t>
+__global__
+void copy_matrix_kernel(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst )
+{
+    copy_matrix_device( m, n, src, ld_src, dst, ld_dst );
+}
+
+namespace blas {
+
+//------------------------------------------------------------------------------
+// Copy m-by-n src matrix to dst matrix, with type conversion; no-op if m or n is 0.
+template <typename src_t, typename dst_t>
+void copy_matrix(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst,
+    blas::Queue &queue)
+{
+    if (m == 0 || n == 0) return;  // quick return: a zero-sized grid is an invalid launch
+    hipStream_t stream = queue.stream();
+
+    // CUDA has max x grid dimension of 2^31-1; y, z grid dimensions of 65535.
+    dim3 threads( blk_x, blk_y );
+    dim3 blocks( ceildiv( m, blk_x ), std::min( 4, ceildiv( n, blk_y ) ) );
+
+    // printf( "%s: m %d, n %d; threads %d, %d, %d; blocks %d, %d, %d\n",
+    //         __func__,
+    //         threads.x, threads.y, threads.z,
+    //         blocks.x,  blocks.y,  blocks.z );
+
+    copy_matrix_kernel<<< blocks, threads, 0, stream >>>
+        ( m, n, src, ld_src, dst, ld_dst );
+
+    // Check that launch succeeded. This doesn't say execution will be
+    // successful, it checks only the launch.
+    blas_dev_call(
+        hipGetLastError() );
+}
+
+//------------------------------------------------------------------------------
+// Explicit instantiations.
+template <>
+void copy_matrix<float, float16>(
+    int m, int n,
+    float const* src, int ld_src,
+    float16*     dst, int ld_dst,
+    blas::Queue &queue)
+{
+    copy_matrix(
+        m, n,
+                  src, ld_src,
+        (__half*) dst, ld_dst,
+        queue );
+}
+
+template <>
+void copy_matrix<float16, float>(
+    int m, int n,
+    float16 const* src, int ld_src,
+    float*         dst, int ld_dst,
+    blas::Queue &queue)
+{
+    copy_matrix(
+        m, n,
+        (__half const*) src, ld_src,  // reinterpret float16 as __half; keep const
+                        dst, ld_dst,
+        queue );
+}
+
+template  // explicit instantiation of the generic copy_matrix, not a specialization
+void copy_matrix<float, float>(
+    int m, int n,
+    float const* src, int ld_src,
+    float*       dst, int ld_dst,
+    blas::Queue &queue);
+} // namespace blas
diff --git a/test/hip/utils.hip.cc.dep b/test/hip/utils.hip.cc.dep
new file mode 100644
index 00000000..e4c0b32b
--- /dev/null
+++ b/test/hip/utils.hip.cc.dep
@@ -0,0 +1 @@
+d5e89fe517631c748f7daa8ff37e92d2  test/cuda/utils.cu
diff --git a/test/utils.cuh b/test/utils.cuh
deleted file mode 100644
index 3c481bf5..00000000
--- a/test/utils.cuh
+++ /dev/null
@@ -1,17 +0,0 @@
-#ifndef UTILS_CUH
-#define UTILS_CUH
-
-#include "blas.hh"
-
-namespace blas {
-
-template <typename src_t, typename dst_t = src_t>
-void copy_matrix(
-    int m, int n,
-    src_t const* src, int ld_src,
-    dst_t*       dst, int ld_dst,
-    blas::Queue &queue);
-
-} // namespace blas
-
-#endif // UTILS_CUH
diff --git a/tools/md5sum.pl b/tools/md5sum.pl
new file mode 100755
index 00000000..f7658bce
--- /dev/null
+++ b/tools/md5sum.pl
@@ -0,0 +1,23 @@
+#!/usr/bin/perl
+#
+# Generate md5 sums of files, with output compatible with md5sum.
+# Doesn't support any options of md5sum, though (--check, etc.).
+# Exits with non-zero status if any file could not be opened.
+
+use strict;
+use warnings;
+use Digest::MD5;
+
+my $failures = 0;
+foreach my $filename (@ARGV) {
+    my $file;
+    if (not open( $file, '<', $filename )) {
+        warn "$0: $filename: $!\n";
+        $failures += 1;
+        next;
+    }
+    binmode( $file );
+    print Digest::MD5->new->addfile( $file )->hexdigest, "  $filename\n";
+    close( $file );
+}
+exit( $failures ? 1 : 0 );

From 7edd364776da1cc5476e61f24f90cea5e897d5cd Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 4 Oct 2023 17:53:39 -0400
Subject: [PATCH 09/25] hgemm: Add hip support for hgemm

---
 src/rocblas_wrappers.cc | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/rocblas_wrappers.cc b/src/rocblas_wrappers.cc
index a0581bd0..7894cdf7 100644
--- a/src/rocblas_wrappers.cc
+++ b/src/rocblas_wrappers.cc
@@ -484,6 +484,31 @@ void copy(
 //------------------------------------------------------------------------------
 // gemm
 //------------------------------------------------------------------------------
+void gemm(
+    blas::Op transA, blas::Op transB,
+    device_blas_int m, device_blas_int n, device_blas_int k,
+    float16 alpha,
+    float16 const *dA, device_blas_int ldda,
+    float16 const *dB, device_blas_int lddb,
+    float16 beta,
+    float16       *dC, device_blas_int lddc,
+    blas::Queue& queue )
+{
+    // Cast from blas::float16 to rocblas_half
+    rocblas_half alpha_ = *(rocblas_half*)&alpha;
+    rocblas_half beta_  = *(rocblas_half*)&beta;
+
+    blas_dev_call(
+        rocblas_hgemm(
+            queue.handle(),
+            op2rocblas(transA), op2rocblas(transB),
+            m, n, k,
+            &alpha_, (rocblas_half const*)dA, ldda,
+                     (rocblas_half const*)dB, lddb,
+            &beta_,  (rocblas_half*)dC, lddc ) );
+}
+
+//------------------------------------------------------------------------------
 void gemm(
     blas::Op transA, blas::Op transB,
     device_blas_int m, device_blas_int n, device_blas_int k,

From ffac1ae9115d34fdf211357ddc75ff6665d13e39 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 4 Oct 2023 17:56:13 -0400
Subject: [PATCH 10/25] hgemm: fix compilation after cleaning.

---
 test/cuda/utils.cu        | 6 ++++++
 test/hip/utils.hip.cc     | 6 ++++++
 test/hip/utils.hip.cc.dep | 2 +-
 test/test_gemm_device.cc  | 2 +-
 4 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/test/cuda/utils.cu b/test/cuda/utils.cu
index 01003173..57b5a38c 100644
--- a/test/cuda/utils.cu
+++ b/test/cuda/utils.cu
@@ -1,5 +1,11 @@
 #include "../utils.hh"
 
+#if defined(BLAS_HAVE_CUBLAS)
+    #include <cuda_fp16.h>
+#elif defined(BLAS_HAVE_ROCBLAS)
+    #include <hip/hip_fp16.h>
+#endif
+
 //------------------------------------------------------------------------------
 /// @return ceil( x / y ), for integer type T.
 template <typename T>
diff --git a/test/hip/utils.hip.cc b/test/hip/utils.hip.cc
index e42fe4c5..cd46eeda 100644
--- a/test/hip/utils.hip.cc
+++ b/test/hip/utils.hip.cc
@@ -1,6 +1,12 @@
 #include "hip/hip_runtime.h"
 #include "../utils.hh"
 
+#if defined(BLAS_HAVE_CUBLAS)
+    #include <hip/hip_fp16.h>
+#elif defined(BLAS_HAVE_ROCBLAS)
+    #include <hip/hip_fp16.h>
+#endif
+
 //------------------------------------------------------------------------------
 /// @return ceil( x / y ), for integer type T.
 template <typename T>
diff --git a/test/hip/utils.hip.cc.dep b/test/hip/utils.hip.cc.dep
index e4c0b32b..9b338030 100644
--- a/test/hip/utils.hip.cc.dep
+++ b/test/hip/utils.hip.cc.dep
@@ -1 +1 @@
-d5e89fe517631c748f7daa8ff37e92d2  test/cuda/utils.cu
+7277516b6e785d5947d13ce9cac5b4f4  test/cuda/utils.cu
diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc
index 2344f108..c627a3bf 100644
--- a/test/test_gemm_device.cc
+++ b/test/test_gemm_device.cc
@@ -10,7 +10,7 @@
 #include "print_matrix.hh"
 #include "check_gemm.hh"
 
-#include "utils.cuh"
+#include "utils.hh"
 
 // -----------------------------------------------------------------------------
 template <typename TA, typename TB, typename TC>

From bd9511d50566833d73c29ed55d3f0b8fabaf83f0 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 18 Oct 2023 15:55:09 -0400
Subject: [PATCH 11/25] test: gemm change the bound by removing sqrt.

---
 test/check_gemm.hh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/check_gemm.hh b/test/check_gemm.hh
index f7ce94b0..6c7433e9 100644
--- a/test/check_gemm.hh
+++ b/test/check_gemm.hh
@@ -64,8 +64,8 @@ void check_gemm(
     real_t work[1], Cout_norm;
     Cout_norm = lapack_lange( "f", m, n, C, ldc, work );
     error[0] = Cout_norm
-             / (sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
-                 + 2 * abs( beta ) * Cnorm);
+             / ( ( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
+                 + 2 * abs( beta ) * Cnorm );
     if (verbose) {
-        printf( "error: ||Cout||=%.2e / (sqrt(k=%lld + 2)"
+        printf( "error: ||Cout||=%.2e / ((k=%lld + 2)"
                 " * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"

From 44a9e73c2b2e6a1cbcc02eefb30402eb122c87fb Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 18 Oct 2023 15:55:55 -0400
Subject: [PATCH 12/25] test: fix scalar type used to get the flop count in
 half gemm.

---
 test/test_gemm_device.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc
index c627a3bf..669e3dfa 100644
--- a/test/test_gemm_device.cc
+++ b/test/test_gemm_device.cc
@@ -357,7 +357,7 @@ void test_gemm_device_work<blas::float16,blas::float16,blas::float16>( Params& p
     queue.sync();
     time = get_wtime() - time;
 
-    double gflop = blas::Gflop< scalar_hi >::gemm( m, n, k );
+    double gflop = blas::Gflop< scalar_lo >::gemm( m, n, k );
     params.time()   = time;
     params.gflops() = gflop / time;
 

From 82752bfedee16c82abf0864871b43df6e6185e5a Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 18 Oct 2023 20:43:36 +0000
Subject: [PATCH 13/25] hgemm: Use a class float16 instead of an alias.

---
 include/blas/util.hh | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/include/blas/util.hh b/include/blas/util.hh
index c88cfe2d..c692050b 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -14,9 +14,41 @@
 
 #include <assert.h>
 
+#include <blas/defines.h>
+
+#ifdef BLAS_HAVE_CUBLAS
+  #include <cuda_fp16.h>
+#elif defined(BLAS_HAVE_ROCBLAS)
+  #include <hip/hip_fp16.h>
+#endif
+
+
 namespace blas {
 
-using float16 = _Float16;
+class float16 {
+
+#if BLAS_USE_ISO_FLOAT16
+  using float16_ = _Float16;
+#elif defined(BLAS_HAVE_CUBLAS)
+  using float16_ = __half;
+#elif defined(BLAS_HAVE_ROCBLAS)
+  using float16_ = rocblas__half;
+#else
+  using float16_ = uint16_t;
+#endif
+
+  float16_ value_;
+
+  public:
+
+  float16() : value_( 0.0f ) { }
+
+  float16( float v ) : value_( v ) { }
+
+  operator float() const {
+    return float( value_ );
+  }
+};
 
 inline float real(float16& a) { return float( a ); }
 inline float imag(float16& a) { return 0.0f; }

From b3c836cde61ad32f2aed7fb3ccd0f8356e9a68f7 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 18 Oct 2023 22:19:40 +0000
Subject: [PATCH 14/25] config: fix gpu_backend name issue.

---
 config/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/config/config.py b/config/config.py
index 8de06489..4530a458 100644
--- a/config/config.py
+++ b/config/config.py
@@ -77,7 +77,7 @@ def define( var, value=None ):
 
 # ------------------------------------------------------------------------------
 # variables to replace instead of appending/prepending
-replace_vars = ['CC', 'CXX', 'NVCC', 'FC', 'AR', 'RANLIB', 'prefix']
+replace_vars = ['CC', 'CXX', 'NVCC', 'FC', 'AR', 'RANLIB', 'prefix', 'gpu_backend']
 
 # ------------------------------------------------------------------------------
 # map file extensions to languages

From 6a181e80505f2fefd7c084cbd69bca53584f1e4d Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 18 Oct 2023 18:37:42 -0400
Subject: [PATCH 15/25] hgemm: fix compilation issue.

---
 include/blas/util.hh | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/include/blas/util.hh b/include/blas/util.hh
index c692050b..65315b95 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -17,9 +17,9 @@
 #include <blas/defines.h>
 
 #ifdef BLAS_HAVE_CUBLAS
-  #include <cuda_fp16.h>
+#include <cuda_fp16.h>
 #elif defined(BLAS_HAVE_ROCBLAS)
-  #include <hip/hip_fp16.h>
+#include <hip/hip_fp16.h>
 #endif
 
 
@@ -27,12 +27,12 @@ namespace blas {
 
 class float16 {
 
-#if BLAS_USE_ISO_FLOAT16
+#ifdef BLAS_USE_ISO_FLOAT16
   using float16_ = _Float16;
 #elif defined(BLAS_HAVE_CUBLAS)
   using float16_ = __half;
 #elif defined(BLAS_HAVE_ROCBLAS)
-  using float16_ = rocblas__half;
+  using float16_ = rocblas_half;
 #else
   using float16_ = uint16_t;
 #endif

From 3ced0ce31da7c7214dfa820fa9111d4cb3322e1f Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 18 Oct 2023 18:39:00 -0400
Subject: [PATCH 16/25] TMP: Add explicit compilation flag for reproducer
 purposes.

---
 include/blas/util.hh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/include/blas/util.hh b/include/blas/util.hh
index 65315b95..dcf0c13d 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -22,6 +22,7 @@
 #include <hip/hip_fp16.h>
 #endif
 
+#define BLAS_USE_ISO_FLOAT16
 
 namespace blas {
 

From d1fd85c8518cf15787fa31e3c6da49ac40fa0591 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 29 Nov 2023 00:07:09 +0000
Subject: [PATCH 17/25] hgemm: Search for _Float16 support in the compiler; if
 found, the macro BLAS_USE_ISO_FLOAT16 is defined.

---
 config/config.py     | 18 ++++++++++++++++++
 configure.py         |  1 +
 include/blas/util.hh | 37 ++++++++++++++++++++-----------------
 3 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/config/config.py b/config/config.py
index 4530a458..0e21ac63 100644
--- a/config/config.py
+++ b/config/config.py
@@ -615,6 +615,24 @@ def openmp( flags=['-fopenmp', '-qopenmp', '-openmp', '-omp', ''] ):
     # end
 # end
 
+#-------------------------------------------------------------------------------
+def float16( ):
+    '''
+    Tests for _Float16 support from the compiler.
+    '''
+    print_header( '_Float16 support' )
+    src = 'config/return_float16.cc'
+    cxxflags = define('USE_ISO_FLOAT16')
+    print_test( cxxflags )
+    env = {'CXXFLAGS': cxxflags}
+    (rc, out, err) = compile_run( src, env )
+    print_result( "_Float16", rc )
+    if (rc == 0):
+        environ.merge( env )
+    else:
+        print_msg( font.red( 'skipping _Float16 search' ) )
+# end
+
 #-------------------------------------------------------------------------------
 def cublas_library():
     '''
diff --git a/configure.py b/configure.py
index 00ba0900..0caa5dff 100755
--- a/configure.py
+++ b/configure.py
@@ -58,6 +58,7 @@ def main():
    #config.prog_cxx_flag( '-Werror' )
 
     config.openmp()
+    config.float16()
 
     config.lapack.blas()
     print()
diff --git a/include/blas/util.hh b/include/blas/util.hh
index dcf0c13d..f457d7c5 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -26,30 +26,33 @@
 
 namespace blas {
 
-class float16 {
 
 #ifdef BLAS_USE_ISO_FLOAT16
-  using float16_ = _Float16;
-#elif defined(BLAS_HAVE_CUBLAS)
-  using float16_ = __half;
-#elif defined(BLAS_HAVE_ROCBLAS)
-  using float16_ = rocblas_half;
-#else
-  using float16_ = uint16_t;
-#endif
-
-  float16_ value_;
+  using float16 = _Float16;
 
-  public:
+#elif defined(BLAS_HAVE_CUBLAS)
+  using float16 = __half;
 
-  float16() : value_( 0.0f ) { }
+#elif defined(BLAS_HAVE_ROCBLAS)
+  using float16 = rocblas_half;
 
-  float16( float v ) : value_( v ) { }
+#else
+class float16 {
+    public:
+    float16() : data_( 0.0f ) { }
+    
+    // TODO manipulate the bits here
+    float16( float v ) : data_( v ) { }
+
+    // TODO manipulate the bits here
+    operator float() const {
+        return float( data_ );
+    }
 
-  operator float() const {
-    return float( value_ );
-  }
+    private:
+        uint16_t data_;
 };
+#endif
 
 inline float real(float16& a) { return float( a ); }
 inline float imag(float16& a) { return 0.0f; }

From 8b8320bff9c7fdadc52bb56ea0af0abe9fd96913 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Tue, 28 Nov 2023 23:30:34 +0000
Subject: [PATCH 18/25] hgemm: add CPU support through MKL.

---
 include/blas/fortran.h   | 14 ++++++++++++++
 include/blas/util.hh     |  2 +-
 include/blas/wrappers.hh | 11 +++++++++++
 src/gemm.cc              | 40 ++++++++++++++++++++++++++++++++++++++++
 src/onemkl_wrappers.cc   | 21 +++++++++++++++++++++
 5 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/include/blas/fortran.h b/include/blas/fortran.h
index 89973302..417eb8d6 100644
--- a/include/blas/fortran.h
+++ b/include/blas/fortran.h
@@ -903,6 +903,20 @@ void BLAS_ztrsv_base(
 // =============================================================================
 // Level 3 BLAS - Fortran prototypes
 
+#if defined(BLAS_HAVE_MKL)
+#include <mkl_types.h>
+// -----------------------------------------------------------------------------
+#define BLAS_hgemm BLAS_FORTRAN_NAME( hgemm, HGEMM )
+void BLAS_hgemm(
+    char const *transA, char const *transB,
+    blas_int const *m, blas_int const *n, blas_int const *k,
+    MKL_F16 const *alpha,
+    MKL_F16 const *A, blas_int const *lda,
+    MKL_F16 const *B, blas_int const *ldb,
+    MKL_F16 const *beta,
+    MKL_F16       *C, blas_int const *ldc );
+#endif
+
 // -----------------------------------------------------------------------------
 #define BLAS_sgemm_base BLAS_FORTRAN_NAME( sgemm, SGEMM )
 void BLAS_sgemm_base(
diff --git a/include/blas/util.hh b/include/blas/util.hh
index f457d7c5..85330b42 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -22,7 +22,7 @@
 #include <hip/hip_fp16.h>
 #endif
 
-#define BLAS_USE_ISO_FLOAT16
+//#define BLAS_USE_ISO_FLOAT16
 
 namespace blas {
 
diff --git a/include/blas/wrappers.hh b/include/blas/wrappers.hh
index 3afc14d0..8d68e59a 100644
--- a/include/blas/wrappers.hh
+++ b/include/blas/wrappers.hh
@@ -701,6 +701,17 @@ void trsv(
 // Level 3 BLAS
 
 //------------------------------------------------------------------------------
+void gemm(
+    blas::Layout layout,
+    blas::Op transA,
+    blas::Op transB,
+    int64_t m, int64_t n, int64_t k,
+    float16 alpha,
+    float16 const* A, int64_t lda,
+    float16 const* B, int64_t ldb,
+    float16 beta,
+    float16*       C, int64_t ldc );
+
 void gemm(
     blas::Layout layout,
     blas::Op transA,
diff --git a/src/gemm.cc b/src/gemm.cc
index 2d57509d..fe0dae5e 100644
--- a/src/gemm.cc
+++ b/src/gemm.cc
@@ -9,11 +9,33 @@
 
 #include <limits>
 
+#if defined(BLAS_HAVE_MKL)
+    #include <mkl_blas.h>
+#endif
+
 namespace blas {
 
 //==============================================================================
 namespace internal {
 
+//------------------------------------------------------------------------------
+/// Low-level overload wrapper calls Fortran, float16 version.
+/// @ingroup gemm_internal
+inline void gemm(
+    char transA, char transB,
+    blas_int m, blas_int n, blas_int k,
+    float16 alpha,
+    float16 const* A, blas_int lda,
+    float16 const* B, blas_int ldb,
+    float16 beta,
+    float16*       C, blas_int ldc )
+{
+    BLAS_hgemm( &transA, &transB, &m, &n, &k,
+        (MKL_F16*)&alpha,  (MKL_F16*)A, &lda,
+                           (MKL_F16*)B, &ldb,
+        (MKL_F16*)&beta,   (MKL_F16*)C, &ldc );
+}
+
 //------------------------------------------------------------------------------
 /// Low-level overload wrapper calls Fortran, float version.
 /// @ingroup gemm_internal
@@ -179,6 +201,24 @@ void gemm(
 // When calling a template, all the templated arguments (e.g., scalar_t)
 // must match types exactly.
 
+//------------------------------------------------------------------------------
+/// CPU, float16 version.
+/// @ingroup gemm
+void gemm(
+    blas::Layout layout,
+    blas::Op transA,
+    blas::Op transB,
+    int64_t m, int64_t n, int64_t k,
+    float16 alpha,
+    float16 const* A, int64_t lda,
+    float16 const* B, int64_t ldb,
+    float16 beta,
+    float16*       C, int64_t ldc )
+{
+    impl::gemm( layout, transA, transB, m, n, k,
+                alpha, A, lda, B, ldb, beta, C, ldc );
+}
+
 //------------------------------------------------------------------------------
 /// CPU, float version.
 /// @ingroup gemm
diff --git a/src/onemkl_wrappers.cc b/src/onemkl_wrappers.cc
index e1f3d05e..340eb94e 100644
--- a/src/onemkl_wrappers.cc
+++ b/src/onemkl_wrappers.cc
@@ -496,6 +496,27 @@ void copy(
 
 //------------------------------------------------------------------------------
 // gemm
+//------------------------------------------------------------------------------
+//void gemm(
+//    blas::Op transA, blas::Op transB,
+//    device_blas_int m, device_blas_int n, device_blas_int k,
+//    float16 alpha,
+//    float16 const *dA, device_blas_int ldda,
+//    float16 const *dB, device_blas_int lddb,
+//    float16 beta,
+//    float16       *dC, device_blas_int lddc,
+//    blas::Queue& queue )
+//{
+//    blas_dev_call(
+//        oneapi::mkl::blas::gemm(
+//            queue.stream(),
+//            op2onemkl( transA ), op2onemkl( transB ),
+//            m, n, k,
+//            (MKL_F16)alpha, (MKL_F16*)dA, ldda,
+//                            (MKL_F16*)dB, lddb,
+//            (MKL_F16)beta,  (MKL_F16*)dC, lddc ) );
+//}
+
 //------------------------------------------------------------------------------
 void gemm(
     blas::Op transA, blas::Op transB,

From 30cd556d626d3cf7f1e250aec2d34df2dfaa678e Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Tue, 28 Nov 2023 23:32:19 +0000
Subject: [PATCH 19/25] hgemm: add CPU test with cblas wrapper; add
 cast_onto_device util that casts from and to float16.

---
 test/cblas_wrappers.hh   |  17 ++++
 test/cuda/utils.cu       |  56 +++++++++++
 test/test_gemm.cc        | 195 +++++++++++++++++++++++++++++++++++++++
 test/test_gemm_device.cc |  11 ++-
 test/utils.hh            |   6 ++
 5 files changed, 280 insertions(+), 5 deletions(-)

diff --git a/test/cblas_wrappers.hh b/test/cblas_wrappers.hh
index 60b9a8bc..0e8c2ac6 100644
--- a/test/cblas_wrappers.hh
+++ b/test/cblas_wrappers.hh
@@ -1156,7 +1156,24 @@ cblas_syr2(
 // =============================================================================
 // Level 3 BLAS
 
+#if defined(BLAS_HAVE_MKL)
 // -----------------------------------------------------------------------------
+inline void
+cblas_gemm(
+    CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
+    int m, int n, int k,
+    blas::float16  alpha,
+    blas::float16 const *A, int lda,
+    blas::float16 const *B, int ldb,
+    blas::float16  beta,
+    blas::float16* C, int ldc )
+{
+    cblas_hgemm( layout, transA, transB, m, n, k,
+                 (MKL_F16)alpha, (MKL_F16*)A, lda, (MKL_F16*)B, ldb,
+                 (MKL_F16)beta,  (MKL_F16*)C, ldc );
+}
+#endif
+
 inline void
 cblas_gemm(
     CBLAS_LAYOUT layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
diff --git a/test/cuda/utils.cu b/test/cuda/utils.cu
index 57b5a38c..cdd2e36c 100644
--- a/test/cuda/utils.cu
+++ b/test/cuda/utils.cu
@@ -176,4 +176,60 @@ void copy_matrix<float, float>(
     float const* src, int ld_src,
     float*       dst, int ld_dst,
     blas::Queue &queue);
+
+//------------------------------------------------------------------------------
+// Move the A matrix onto device, and copy it before moving the copy back into B.
+template <typename scalar_from, typename scalar_to>
+void cast_onto_device(
+    int m, int n,
+    const scalar_from* A, int lda,
+    scalar_to*         B, int ldb,
+    blas::Queue &queue)
+{
+  scalar_from* dA;
+  scalar_to* dB;
+
+  blas_dev_call( cudaMalloc( &dA, sizeof(scalar_from) * m*n ) );
+  blas_dev_call( cudaMalloc( &dB, sizeof(scalar_to)   * m*n ) );
+
+  blas_dev_call( cudaMemcpy( dA, A, sizeof(scalar_from) * m*n, cudaMemcpyHostToDevice ) );
+
+  copy_matrix( m, n, dA, lda, dB, ldb, queue );
+  queue.sync();
+
+  blas_dev_call( cudaMemcpy( B, dB, sizeof(scalar_to) * m*n, cudaMemcpyDeviceToHost ) );
+
+  blas_dev_call( cudaFree( dA ) );
+  blas_dev_call( cudaFree( dB ) );
+}
+
+//------------------------------------------------------------------------------
+// Explicit instantiations.
+template <>
+void cast_onto_device<float16, float>(
+    int m, int n,
+    const float16* A, int lda,
+    float*         B, int ldb,
+    blas::Queue &queue)
+{
+  cast_onto_device(
+      m, n,
+      (__half*) A, lda,
+                B, ldb,
+      queue );
+}
+
+template <>
+void cast_onto_device<float, float16>(
+    int m, int n,
+    const float* A, int lda,
+    float16*     B, int ldb,
+    blas::Queue &queue)
+{
+  cast_onto_device(
+      m, n,
+                A, lda,
+      (__half*) B, ldb,
+      queue );
+}
 } // namespace blas
diff --git a/test/test_gemm.cc b/test/test_gemm.cc
index 55605ed3..091ebb72 100644
--- a/test/test_gemm.cc
+++ b/test/test_gemm.cc
@@ -10,6 +10,8 @@
 #include "print_matrix.hh"
 #include "check_gemm.hh"
 
+#include "utils.hh"
+
 // -----------------------------------------------------------------------------
 template <typename TA, typename TB, typename TC>
 void test_gemm_work( Params& params, bool run )
@@ -170,10 +172,203 @@ void test_gemm_work( Params& params, bool run )
     delete[] Cref;
 }
 
+// -----------------------------------------------------------------------------
+template <>
+void test_gemm_work<blas::float16, blas::float16, blas::float16>(
+    Params& params, bool run )
+{
+    using namespace testsweeper;
+    using std::real;
+    using std::imag;
+    using blas::Op;
+    using blas::Layout;
+    using scalar_hi = float;
+    using scalar_lo = blas::float16;
+    using real_t   = blas::real_type< scalar_hi >;
+
+    // get & mark input values
+    blas::Layout layout = params.layout();
+    blas::Op transA  = params.transA();
+    blas::Op transB  = params.transB();
+    scalar_lo alpha  = params.alpha();
+    scalar_lo beta   = params.beta();
+    int64_t m        = params.dim.m();
+    int64_t n        = params.dim.n();
+    int64_t k        = params.dim.k();
+    int64_t align    = params.align();
+    int64_t verbose  = params.verbose();
+
+    // mark non-standard output values
+    params.gflops();
+    params.ref_time();
+    params.ref_gflops();
+
+    if (! run)
+        return;
+
+    if (blas::get_device_count() == 0) {
+        params.msg() = "skipping: no GPU devices or no GPU support";
+        return;
+    }
+
+    int64_t device = 0;//params.device();
+    // device specifics
+    blas::Queue queue( device );
+
+    // setup
+    int64_t Am = (transA == Op::NoTrans ? m : k);
+    int64_t An = (transA == Op::NoTrans ? k : m);
+    int64_t Bm = (transB == Op::NoTrans ? k : n);
+    int64_t Bn = (transB == Op::NoTrans ? n : k);
+    int64_t Cm = m;
+    int64_t Cn = n;
+    if (layout == Layout::RowMajor) {
+        std::swap( Am, An );
+        std::swap( Bm, Bn );
+        std::swap( Cm, Cn );
+    }
+    int64_t lda = roundup( Am, align );
+    int64_t ldb = roundup( Bm, align );
+    int64_t ldc = roundup( Cm, align );
+    size_t size_A = size_t(lda)*An;
+    size_t size_B = size_t(ldb)*Bn;
+    size_t size_C = size_t(ldc)*Cn;
+    blas::float16* A_lo    = new blas::float16[ size_A ];
+    blas::float16* B_lo    = new blas::float16[ size_B ];
+    blas::float16* C_lo    = new blas::float16[ size_C ];
+    float* A_hi = new float[ size_A ];
+    float* B_hi = new float[ size_B ];
+    float* C_hi = new float[ size_C ];
+    float* Cref = new float[ size_C ];
+
+    int64_t idist = 1;
+    int iseed[4] = { 0, 0, 0, 1 };
+    lapack_larnv( idist, iseed, size_A, A_hi );
+    lapack_larnv( idist, iseed, size_B, B_hi );
+    lapack_larnv( idist, iseed, size_C, C_hi );
+    lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc );
+
+    // Convert float->float16
+    blas::cast_onto_device( Am, An, A_hi, lda, A_lo, lda, queue );
+    blas::cast_onto_device( Bm, Bn, B_hi, ldb, B_lo, ldb, queue );
+    blas::cast_onto_device( Cm, Cn, C_hi, ldc, C_lo, ldc, queue );
+    queue.sync();
+
+    // norms for error check
+    real_t work[1];
+    real_t Anorm = lapack_lange( "f", Am, An, A_hi, lda, work );
+    real_t Bnorm = lapack_lange( "f", Bm, Bn, B_hi, ldb, work );
+    real_t Cnorm = lapack_lange( "f", Cm, Cn, C_hi, ldc, work );
+
+    // test error exits
+    assert_throw( blas::gemm( Layout(0), transA, transB,  m,  n,  k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( layout,    Op(0),  transB,  m,  n,  k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, Op(0),   m,  n,  k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, transB, -1,  n,  k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, transB,  m, -1,  k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( layout,    transA, transB,  m,  n, -1, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans,   Op::NoTrans, m, n, k, alpha, A_hi, m-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::Trans,     Op::NoTrans, m, n, k, alpha, A_hi, k-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, A_hi, k-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans,   Op::NoTrans, m, n, k, alpha, A_hi, k-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::Trans,     Op::NoTrans, m, n, k, alpha, A_hi, m-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, A_hi, m-1, B_hi, ldb, beta, C_hi, ldc ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, A_hi, lda, B_hi, k-1, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, A_hi, lda, B_hi, n-1, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, A_hi, lda, B_hi, n-1, beta, C_hi, ldc ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans,   m, n, k, alpha, A_hi, lda, B_hi, n-1, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans,     m, n, k, alpha, A_hi, lda, B_hi, k-1, beta, C_hi, ldc ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, A_hi, lda, B_hi, k-1, beta, C_hi, ldc ), blas::Error );
+
+    assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, m-1 ), blas::Error );
+    assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, C_hi, n-1 ), blas::Error );
+
+    if (verbose >= 1) {
+        printf( "\n"
+                "A Am=%5lld, An=%5lld, lda=%5lld, size=%10lld, norm %.2e\n"
+                "B Bm=%5lld, Bn=%5lld, ldb=%5lld, size=%10lld, norm %.2e\n"
+                "C Cm=%5lld, Cn=%5lld, ldc=%5lld, size=%10lld, norm %.2e\n",
+                llong( Am ), llong( An ), llong( lda ), llong( size_A ), Anorm,
+                llong( Bm ), llong( Bn ), llong( ldb ), llong( size_B ), Bnorm,
+                llong( Cm ), llong( Cn ), llong( ldc ), llong( size_C ), Cnorm );
+    }
+    if (verbose >= 2) {
+        printf( "alpha = %.4e + %.4ei; beta = %.4e + %.4ei;\n",
+                real(alpha), imag(alpha),
+                real(beta),  imag(beta) );
+        printf( "A = "    ); print_matrix( Am, An, A_hi, lda );
+        printf( "B = "    ); print_matrix( Bm, Bn, B_hi, ldb );
+        printf( "C = "    ); print_matrix( Cm, Cn, C_hi, ldc );
+    }
+
+    // run test
+    testsweeper::flush_cache( params.cache() );
+    double time = get_wtime();
+    blas::gemm( layout, transA, transB, m, n, k,
+                alpha, A_lo, lda, B_lo, ldb, beta, C_lo, ldc );
+    time = get_wtime() - time;
+
+    double gflop = blas::Gflop< scalar_lo >::gemm( m, n, k );
+    params.time()   = time;
+    params.gflops() = gflop / time;
+
+    // Convert float16->float
+    blas::cast_onto_device( Cm, Cn, C_lo, ldc, C_hi, ldc, queue );
+    queue.sync();
+
+    if (verbose >= 2) {
+        printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc );
+    }
+
+    if (params.ref() == 'y' || params.check() == 'y') {
+        // run reference
+        testsweeper::flush_cache( params.cache() );
+        time = get_wtime();
+        cblas_gemm( cblas_layout_const(layout),
+                    cblas_trans_const(transA),
+                    cblas_trans_const(transB),
+                    m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, Cref, ldc );
+        time = get_wtime() - time;
+
+        params.ref_time()   = time;
+        params.ref_gflops() = gflop / time;
+
+        if (verbose >= 2) {
+            printf( "Cref = " ); print_matrix( Cm, Cn, Cref, ldc );
+        }
+
+        // check error compared to reference
+        real_t error;
+        bool okay;
+        check_gemm<float, blas::float16>(
+            Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
+            Cref, ldc, C_hi, ldc, verbose, &error, &okay );
+        params.error() = error;
+        params.okay() = okay;
+    }
+
+    delete[] A_lo;
+    delete[] B_lo;
+    delete[] C_lo;
+    delete[] A_hi;
+    delete[] B_hi;
+    delete[] C_hi;
+    delete[] Cref;
+}
+
 // -----------------------------------------------------------------------------
 void test_gemm( Params& params, bool run )
 {
     switch (params.datatype()) {
+        case testsweeper::DataType::Half:
+            test_gemm_work< blas::float16, blas::float16, blas::float16 >(
+                params, run );
+            break;
+
         case testsweeper::DataType::Single:
             test_gemm_work< float, float, float >( params, run );
             break;
diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc
index 669e3dfa..4ee0ce6d 100644
--- a/test/test_gemm_device.cc
+++ b/test/test_gemm_device.cc
@@ -288,9 +288,9 @@ void test_gemm_device_work<blas::float16,blas::float16,blas::float16>( Params& p
     lapack_larnv( idist, iseed, size_C, C_hi );
     lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc );
 
-    blas::device_copy_matrix(Am, An, A_hi, lda, dA_hi, lda, queue);
-    blas::device_copy_matrix(Bm, Bn, B_hi, ldb, dB_hi, ldb, queue);
-    blas::device_copy_matrix(Cm, Cn, C_hi, ldc, dC_hi, ldc, queue);
+    blas::device_copy_matrix( Am, An, A_hi, lda, dA_hi, lda, queue );
+    blas::device_copy_matrix( Bm, Bn, B_hi, ldb, dB_hi, ldb, queue );
+    blas::device_copy_matrix( Cm, Cn, C_hi, ldc, dC_hi, ldc, queue );
 
     // Convert float->float16
     blas::copy_matrix( Am, An, dA_hi, lda, dA_lo, lda, queue );
@@ -390,8 +390,9 @@ void test_gemm_device_work<blas::float16,blas::float16,blas::float16>( Params& p
         // check error compared to reference
         real_t error;
         bool okay;
-        check_gemm<float, blas::float16>( Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
-                    Cref, ldc, C_hi, ldc, verbose, &error, &okay );
+        check_gemm<float, blas::float16>(
+            Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
+            Cref, ldc, C_hi, ldc, verbose, &error, &okay );
         params.error() = error;
         params.okay() = okay;
     }
diff --git a/test/utils.hh b/test/utils.hh
index 1ee0d848..8e038f44 100644
--- a/test/utils.hh
+++ b/test/utils.hh
@@ -12,6 +12,12 @@ void copy_matrix(
     dst_t*       dst, int ld_dst,
     blas::Queue &queue);
 
+template <typename scalar_from, typename scalar_to>
+void cast_onto_device(
+    int m, int n,
+    const scalar_from* A, int lda,
+    scalar_to*         B, int ldb,
+    blas::Queue &queue);
 } // namespace blas
 
 #endif // UTILS_HH

From d6864512aaa1742fb2b779c7678f35be8cd77499 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Fri, 1 Dec 2023 15:34:26 +0000
Subject: [PATCH 20/25] float16: update configure search and macro definition.

---
 config/config.py | 2 +-
 configure.py     | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/config/config.py b/config/config.py
index 0e21ac63..e151ca70 100644
--- a/config/config.py
+++ b/config/config.py
@@ -622,7 +622,7 @@ def float16( ):
     '''
     print_header( '_Float16 support' )
     src = 'config/return_float16.cc'
-    cxxflags = define('USE_ISO_FLOAT16')
+    cxxflags = define('HAVE_ISO_FLOAT16')
     print_test( cxxflags )
     env = {'CXXFLAGS': cxxflags}
     (rc, out, err) = compile_run( src, env )
diff --git a/configure.py b/configure.py
index 0caa5dff..6491a00d 100755
--- a/configure.py
+++ b/configure.py
@@ -57,9 +57,10 @@ def main():
    #config.prog_cxx_flag( '-Wconversion' )
    #config.prog_cxx_flag( '-Werror' )
 
-    config.openmp()
     config.float16()
 
+    config.openmp()
+
     config.lapack.blas()
     print()
     config.lapack.blas_float_return()

From 5a30ad8fc111a9767d8cb58c021390388ae6567e Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Fri, 1 Dec 2023 15:38:52 +0000
Subject: [PATCH 21/25] hgemm: Fake casting in cublas_wrapper through pointer
 casting.

---
 src/cublas_wrappers.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/cublas_wrappers.cc b/src/cublas_wrappers.cc
index 260ba763..33b73d59 100644
--- a/src/cublas_wrappers.cc
+++ b/src/cublas_wrappers.cc
@@ -489,8 +489,8 @@ void gemm(
     float16       *dC, device_blas_int lddc,
     blas::Queue& queue )
 {
-    __half alpha_ = __half( alpha );
-    __half beta_  = __half( beta );
+    __half alpha_ = *((__half*)&alpha);
+    __half beta_  = *((__half*)&beta);
 
     blas_dev_call(
         cublasHgemm(

From f4fed1aae9028c3ada4eb876063ac63512f145c0 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Fri, 1 Dec 2023 15:44:37 +0000
Subject: [PATCH 22/25] test: in hgemm, add cpu casting support, remove
 cast_onto_device routine; Fix compilation issue.

---
 test/cuda/utils.cu       | 55 ----------------------------------------
 test/test_gemm.cc        | 25 ++++++------------
 test/test_gemm_device.cc |  4 +--
 test/utils.hh            | 11 ++++----
 4 files changed, 15 insertions(+), 80 deletions(-)

diff --git a/test/cuda/utils.cu b/test/cuda/utils.cu
index cdd2e36c..ded6adc6 100644
--- a/test/cuda/utils.cu
+++ b/test/cuda/utils.cu
@@ -177,59 +177,4 @@ void copy_matrix<float, float>(
     float*       dst, int ld_dst,
     blas::Queue &queue);
 
-//------------------------------------------------------------------------------
-// Move the A matrix onto device, and copy it before moving the copy back into B.
-template <typename scalar_from, typename scalar_to>
-void cast_onto_device(
-    int m, int n,
-    const scalar_from* A, int lda,
-    scalar_to*         B, int ldb,
-    blas::Queue &queue)
-{
-  scalar_from* dA;
-  scalar_to* dB;
-
-  blas_dev_call( cudaMalloc( &dA, sizeof(scalar_from) * m*n ) );
-  blas_dev_call( cudaMalloc( &dB, sizeof(scalar_to)   * m*n ) );
-
-  blas_dev_call( cudaMemcpy( dA, A, sizeof(scalar_from) * m*n, cudaMemcpyHostToDevice ) );
-
-  copy_matrix( m, n, dA, lda, dB, ldb, queue );
-  queue.sync();
-
-  blas_dev_call( cudaMemcpy( B, dB, sizeof(scalar_to) * m*n, cudaMemcpyDeviceToHost ) );
-
-  blas_dev_call( cudaFree( dA ) );
-  blas_dev_call( cudaFree( dB ) );
-}
-
-//------------------------------------------------------------------------------
-// Explicit instantiations.
-template <>
-void cast_onto_device<float16, float>(
-    int m, int n,
-    const float16* A, int lda,
-    float*         B, int ldb,
-    blas::Queue &queue)
-{
-  cast_onto_device(
-      m, n,
-      (__half*) A, lda,
-                B, ldb,
-      queue );
-}
-
-template <>
-void cast_onto_device<float, float16>(
-    int m, int n,
-    const float* A, int lda,
-    float16*     B, int ldb,
-    blas::Queue &queue)
-{
-  cast_onto_device(
-      m, n,
-                A, lda,
-      (__half*) B, ldb,
-      queue );
-}
 } // namespace blas
diff --git a/test/test_gemm.cc b/test/test_gemm.cc
index 091ebb72..549ac95e 100644
--- a/test/test_gemm.cc
+++ b/test/test_gemm.cc
@@ -180,6 +180,8 @@ void test_gemm_work<blas::float16, blas::float16, blas::float16>(
     using namespace testsweeper;
     using std::real;
     using std::imag;
+    using blas::real;
+    using blas::imag;
     using blas::Op;
     using blas::Layout;
     using scalar_hi = float;
@@ -190,8 +192,8 @@ void test_gemm_work<blas::float16, blas::float16, blas::float16>(
     blas::Layout layout = params.layout();
     blas::Op transA  = params.transA();
     blas::Op transB  = params.transB();
-    scalar_lo alpha  = params.alpha();
-    scalar_lo beta   = params.beta();
+    scalar_lo alpha  = (scalar_lo)params.alpha();
+    scalar_lo beta   = (scalar_lo)params.beta();
     int64_t m        = params.dim.m();
     int64_t n        = params.dim.n();
     int64_t k        = params.dim.k();
@@ -206,15 +208,6 @@ void test_gemm_work<blas::float16, blas::float16, blas::float16>(
     if (! run)
         return;
 
-    if (blas::get_device_count() == 0) {
-        params.msg() = "skipping: no GPU devices or no GPU support";
-        return;
-    }
-
-    int64_t device = 0;//params.device();
-    // device specifics
-    blas::Queue queue( device );
-
     // setup
     int64_t Am = (transA == Op::NoTrans ? m : k);
     int64_t An = (transA == Op::NoTrans ? k : m);
@@ -249,10 +242,9 @@ void test_gemm_work<blas::float16, blas::float16, blas::float16>(
     lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc );
 
     // Convert float->float16
-    blas::cast_onto_device( Am, An, A_hi, lda, A_lo, lda, queue );
-    blas::cast_onto_device( Bm, Bn, B_hi, ldb, B_lo, ldb, queue );
-    blas::cast_onto_device( Cm, Cn, C_hi, ldc, C_lo, ldc, queue );
-    queue.sync();
+    blas::copy_matrix( Am, An, A_hi, lda, A_lo, lda );
+    blas::copy_matrix( Bm, Bn, B_hi, ldb, B_lo, ldb );
+    blas::copy_matrix( Cm, Cn, C_hi, ldc, C_lo, ldc );
 
     // norms for error check
     real_t work[1];
@@ -317,8 +309,7 @@ void test_gemm_work<blas::float16, blas::float16, blas::float16>(
     params.gflops() = gflop / time;
 
     // Convert float16->float
-    blas::cast_onto_device( Cm, Cn, C_lo, ldc, C_hi, ldc, queue );
-    queue.sync();
+    blas::copy_matrix( Cm, Cn, C_lo, ldc, C_hi, ldc );
 
     if (verbose >= 2) {
         printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc );
diff --git a/test/test_gemm_device.cc b/test/test_gemm_device.cc
index 4ee0ce6d..3ef74701 100644
--- a/test/test_gemm_device.cc
+++ b/test/test_gemm_device.cc
@@ -217,8 +217,8 @@ void test_gemm_device_work<blas::float16,blas::float16,blas::float16>( Params& p
     blas::Layout layout = params.layout();
     blas::Op transA     = params.transA();
     blas::Op transB     = params.transB();
-    scalar_lo alpha     = params.alpha();
-    scalar_lo beta      = params.beta();
+    scalar_lo alpha     = (scalar_lo)params.alpha();
+    scalar_lo beta      = (scalar_lo)params.beta();
     int64_t m           = params.dim.m();
     int64_t n           = params.dim.n();
     int64_t k           = params.dim.k();
diff --git a/test/utils.hh b/test/utils.hh
index 8e038f44..748ff3d2 100644
--- a/test/utils.hh
+++ b/test/utils.hh
@@ -9,14 +9,13 @@ template <typename src_t, typename dst_t = src_t>
 void copy_matrix(
     int m, int n,
     src_t const* src, int ld_src,
-    dst_t*       dst, int ld_dst,
-    blas::Queue &queue);
+    dst_t*       dst, int ld_dst);
 
-template <typename scalar_from, typename scalar_to>
-void cast_onto_device(
+template <typename src_t, typename dst_t = src_t>
+void copy_matrix(
     int m, int n,
-    const scalar_from* A, int lda,
-    scalar_to*         B, int ldb,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst,
     blas::Queue &queue);
 } // namespace blas
 

From d0d2867a469f78ec4a3a0dc4c4c9c3885aa76d2c Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Tue, 5 Dec 2023 22:41:31 +0000
Subject: [PATCH 23/25] float16: add casting routines from/to fp16 to/from
 fp32.

---
 include/blas/util.hh | 103 +++++++++++++++++++++++++++++++++++++++----
 test/utils.cc        |  33 ++++++++++++++
 2 files changed, 127 insertions(+), 9 deletions(-)
 create mode 100644 test/utils.cc

diff --git a/include/blas/util.hh b/include/blas/util.hh
index 85330b42..c16bb032 100644
--- a/include/blas/util.hh
+++ b/include/blas/util.hh
@@ -20,14 +20,13 @@
 #include <cuda_fp16.h>
 #elif defined(BLAS_HAVE_ROCBLAS)
 #include <hip/hip_fp16.h>
+#elif defined(BLAS_HAVE_MKL)
+#include <mkl_types.h>
 #endif
 
-//#define BLAS_USE_ISO_FLOAT16
-
 namespace blas {
 
-
-#ifdef BLAS_USE_ISO_FLOAT16
+#ifdef BLAS_HAVE_ISO_FLOAT16
   using float16 = _Float16;
 
 #elif defined(BLAS_HAVE_CUBLAS)
@@ -38,19 +37,105 @@ namespace blas {
 
 #else
 class float16 {
-    public:
+#if defined(BLAS_HAVE_MKL)
+    using float16_ = MKL_F16;
+#else
+    using float16_ = uint16_t;
+#endif
+
+public:
     float16() : data_( 0.0f ) { }
     
     // TODO manipulate the bits here
-    float16( float v ) : data_( v ) { }
+    float16( float v ) { data_ = float_to_float16( v ); }
 
     // TODO manipulate the bits here
     operator float() const {
-        return float( data_ );
+        return float16_to_float( data_ );
     }
 
-    private:
-        uint16_t data_;
+private:
+    float16_ data_;
+
+    typedef union {                // overlays the float16 storage word with its three bit-fields
+      float16_  data;              // raw storage, same type as the class member data_
+      struct {
+        unsigned int frac : 10;    // fraction field (10 bits)
+        unsigned int exp  :  5;    // exponent field (5 bits), stored with a bias
+        unsigned int sign :  1;    // sign field (1 bit)
+      } bits;                      // NOTE(review): bit-field placement is implementation-defined; presumably relies on LSB-first allocation -- confirm per target
+    } float16_repr_data_t;
+
+    typedef union {                // overlays a float with its three bit-fields
+      float data;                  // the value as a float
+      struct {
+        unsigned int frac : 23;    // fraction field (23 bits)
+        unsigned int exp  :  8;    // exponent field (8 bits), stored with a bias
+        unsigned int sign :  1;    // sign field (1 bit)
+      } bits;                      // NOTE(review): bit-field placement is implementation-defined; presumably relies on LSB-first allocation -- confirm per target
+    } float_repr_data_t;
+
+    static float float16_to_float(float16_ x) {
+        float16_repr_data_t src;
+        float_repr_data_t dst;
+        // Widen float16 bits to a float: start from zero, copy the sign, then branch on the exponent field.
+        src.data = x;
+        dst.data = 0;
+        dst.bits.sign = src.bits.sign;
+        // All-ones exponent: infinity when frac == 0 (dst frac stays 0), otherwise NaN.
+        if (src.bits.exp == 0x01fU) {
+            dst.bits.exp  = 0xffU;
+            if (src.bits.frac > 0) {
+                dst.bits.frac = ((src.bits.frac | 0x200U) << 13);  // force the top fraction bit on, keep the payload
+            }
+        } else if (src.bits.exp > 0x00U) {  // normal number
+            dst.bits.exp  = src.bits.exp + ((1 << 7) - (1 << 4));  // exponent field + 112
+            dst.bits.frac = (src.bits.frac << 13);  // 10-bit fraction moved to the top of the 23-bit field
+        } else {
+            unsigned int v = (src.bits.frac << 13);
+            // Exponent field 0: v == 0 leaves dst at (signed) zero; otherwise normalize the subnormal.
+            if (v > 0) {
+                dst.bits.exp = 0x71;
+                while ((v & 0x800000UL) == 0) {  // shift left until bit 23 is set, lowering the exponent each step
+                    dst.bits.exp --;
+                    v <<= 1;
+                }
+                dst.bits.frac = v;  // frac is 23 bits wide, so the set bit 23 is not stored
+            }
+        }
+
+        return dst.data;
+    }
+
+    static float16_ float_to_float16(float x) {
+        float_repr_data_t src;
+        float16_repr_data_t dst;
+        // Narrow a float to float16 bits; surplus fraction bits are truncated, not rounded.
+        src.data = x;
+        dst.data = 0;   // magnitudes below 2^-24 match no branch below and stay a signed zero
+        dst.bits.sign = src.bits.sign;
+
+        if (src.bits.exp == 0x0ffU) {           // infinity or NaN
+            dst.bits.exp  = 0x01fU;
+            dst.bits.frac = (src.bits.frac >> 13);
+            if (src.bits.frac > 0) dst.bits.frac |= 0x200U;  // keep a NaN a NaN after truncation
+        } else if (src.bits.exp >= 0x08fU) {    // magnitude >= 2^16: overflow to infinity
+            dst.bits.exp  = 0x01fU;
+            dst.bits.frac = 0x000U;
+        } else if (src.bits.exp >= 0x071U){     // float16 normal range
+            dst.bits.exp  = src.bits.exp + ((1 << 4) - (1 << 7));  // exponent field - 112
+            dst.bits.frac = (src.bits.frac >> 13);
+        } else if (src.bits.exp >= 0x067U){     // float16 subnormal range, 2^-24 .. 2^-14
+            // A subnormal encodes frac * 2^-24 with no implicit leading 1, so
+            // restore the float's hidden bit and shift the 24-bit significand
+            // right by 14 at exp 0x70 and by one more per binade below that,
+            // i.e. by 0x7e - exp (23 at exp 0x67, leaving only the hidden bit).
+            dst.bits.exp  = 0x000;
+            dst.bits.frac = (((1U << 23) | src.bits.frac) >> (0x07eU - src.bits.exp));
+        }
+
+        return dst.data;
+    }
 };
 #endif
 
diff --git a/test/utils.cc b/test/utils.cc
new file mode 100644
index 00000000..798d6b45
--- /dev/null
+++ b/test/utils.cc
@@ -0,0 +1,33 @@
+#include "utils.hh"
+
+namespace blas {
+
+template <typename src_t, typename dst_t>
+void copy_matrix(
+    int m, int n,
+    src_t const* src, int ld_src,
+    dst_t*       dst, int ld_dst)
+{
+  // Copy the m-by-n column-major matrix src (ld_src) into dst (ld_dst), casting each entry to dst_t.
+  #pragma omp parallel for collapse(2)
+  for (int j = 0; j < n; ++j) {    // columns outermost so each column is walked contiguously
+    for (int i = 0; i < m; ++i)    // offsets in long long: j*ld can overflow int on large matrices
+      dst[i + (long long)j*ld_dst] = (dst_t)src[i + (long long)j*ld_src];
+  }
+}
+
+//------------------------------------------------------------------------------
+// Explicit instantiations of the copy_matrix template defined above: float -> float16 and float16 -> float.
+template
+void copy_matrix<float, float16>(
+    int m, int n,
+    float const* src, int ld_src,
+    float16*     dst, int ld_dst);
+
+template
+void copy_matrix<float16, float>(
+    int m, int n,
+    float16 const* src, int ld_src,
+    float*         dst, int ld_dst);
+
+} // namespace blas

From 70d893bee619a757af41279bdc0399eb7d6f82e2 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Wed, 6 Dec 2023 17:55:46 +0000
Subject: [PATCH 24/25] float16: add missing config file.

---
 config/return_float16.cc | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)
 create mode 100644 config/return_float16.cc

diff --git a/config/return_float16.cc b/config/return_float16.cc
new file mode 100644
index 00000000..928246b0
--- /dev/null
+++ b/config/return_float16.cc
@@ -0,0 +1,19 @@
+// Copyright (c) 2017-2022, University of Tennessee. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+// This program is free software: you can redistribute it and/or modify it under
+// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
+
+#include <stdio.h>
+
+#include "config.h"
+
+//------------------------------------------------------------------------------
+int main()
+{
+    // Exercise _Float16 construction, addition, and conversion to float.
+    _Float16 const lhs = 0.1, rhs = 0.2;
+    _Float16 const sum = lhs + rhs;
+    printf( "%f + %f = %f -- expected 0.3\n",
+            (float)lhs, (float)rhs, (float)sum );
+    return 0;
+}

From 8678c8267e32df8b701317f11c1d967fb044b165 Mon Sep 17 00:00:00 2001
From: yardras <sebastien.cayrols@icl.utk.edu>
Date: Thu, 14 Dec 2023 13:56:04 -0500
Subject: [PATCH 25/25] hgemm: enable CPU hgemm only when MKL is provided.

---
 src/gemm.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/gemm.cc b/src/gemm.cc
index fe0dae5e..bfc0d8fa 100644
--- a/src/gemm.cc
+++ b/src/gemm.cc
@@ -30,10 +30,12 @@ inline void gemm(
     float16 beta,
     float16*       C, blas_int ldc )
 {
+#ifdef BLAS_HAVE_MKL
     BLAS_hgemm( &transA, &transB, &m, &n, &k,
         (MKL_F16*)&alpha,  (MKL_F16*)A, &lda,
                            (MKL_F16*)B, &ldb,
         (MKL_F16*)&beta,   (MKL_F16*)C, &ldc );
+#endif
 }
 
 //------------------------------------------------------------------------------