Skip to content

Commit c69f71f

Browse files
committed
hgemm: add hgemm tester; conversion routine from and to half; Nvidia support for now.
1 parent ebe5f98 commit c69f71f

File tree

6 files changed

+412
-4
lines changed

6 files changed

+412
-4
lines changed

GNUmakefile

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ make.inc:
3737
RANLIB ?= ranlib
3838
prefix ?= /opt/slate
3939

40+
NVCC ?= nvcc
41+
42+
NVCCFLAGS += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function'
43+
4044
abs_prefix := ${abspath ${prefix}}
4145

4246
# Default LD=ld won't work; use CXX. Can override in make.inc or environment.
@@ -77,7 +81,7 @@ lib_src = $(wildcard src/*.cc)
7781
lib_obj = $(addsuffix .o, $(basename $(lib_src)))
7882
dep += $(addsuffix .d, $(basename $(lib_src)))
7983

80-
tester_src = $(wildcard test/*.cc)
84+
tester_src = $(wildcard test/*.cc test/*.cu)
8185
tester_obj = $(addsuffix .o, $(basename $(tester_src)))
8286
dep += $(addsuffix .d, $(basename $(tester_src)))
8387

@@ -123,6 +127,7 @@ src/version.o: .id
123127
#-------------------------------------------------------------------------------
124128
# BLAS++ specific flags and libraries
125129
CXXFLAGS += -I./include
130+
NVCCFLAGS += -I./include
126131

127132
# additional flags and libraries for testers
128133
$(tester_obj): CXXFLAGS += -I$(testsweeper_dir)
@@ -289,6 +294,9 @@ hooks: ${hooks}
289294
%.o: %.cc
290295
$(CXX) $(CXXFLAGS) -c $< -o $@
291296

297+
# Pattern rule: build an object file from a CUDA source (.cu) by invoking
# $(NVCC) with $(NVCCFLAGS) instead of the C++ compiler and $(CXXFLAGS).
%.o: %.cu
298+
$(NVCC) $(NVCCFLAGS) -c $< -o $@
299+
292300
# preprocess source
293301
%.i: %.cc
294302
$(CXX) $(CXXFLAGS) -I$(testsweeper_dir) -E $< -o $@

test/check_gemm.hh

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,21 @@
1313

1414
#include <limits>
1515

16+
namespace std {
17+
template <> class numeric_limits<half> {
18+
public:
19+
static half epsilon() {
20+
// Value coming from MAGMA testing/testing_hgemm.cpp
21+
return half( 0.00097656 );
22+
}
23+
};
24+
}; // namespace std
25+
1626
// -----------------------------------------------------------------------------
1727
// Computes error for multiplication with general matrix result.
1828
// Covers dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm?.
1929
// Cnorm is norm of original C, before multiplication operation.
20-
template <typename T>
30+
template <typename T, typename err_prec = T>
2131
void check_gemm(
2232
int64_t m, int64_t n, int64_t k,
2333
T alpha,
@@ -68,7 +78,7 @@ void check_gemm(
6878
error[0] /= 2*sqrt(2);
6979
}
7080

71-
real_t u = 0.5 * std::numeric_limits< real_t >::epsilon();
81+
real_t u = 0.5 * std::numeric_limits<blas::real_type<err_prec>>::epsilon();
7282
*okay = (error[0] < u);
7383

7484
#undef C

test/test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ Params::Params():
212212

213213
// ----- routine parameters
214214
// name, w, type, def, char2enum, enum2char, enum2str, help
215-
datatype ( "type", 4, ParamType::List, DataType::Double, char2datatype, datatype2char, datatype2str, "s=single (float), d=double, c=complex-single, z=complex-double" ),
215+
datatype ( "type", 4, ParamType::List, DataType::Double, char2datatype, datatype2char, datatype2str, "h=half, s=single (float), d=double, c=complex-single, z=complex-double" ),
216216
layout ( "layout", 6, ParamType::List, blas::Layout::ColMajor, blas::char2layout, blas::layout2char, blas::layout2str, "layout: r=row major, c=column major" ),
217217
format ( "format", 6, ParamType::List, blas::Format::LAPACK, blas::char2format, blas::format2char, blas::format2str, "format: l=lapack, t=tile" ),
218218
side ( "side", 6, ParamType::List, blas::Side::Left, blas::char2side, blas::side2char, blas::side2str, "side: l=left, r=right" ),

test/test_gemm_device.cc

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "print_matrix.hh"
1111
#include "check_gemm.hh"
1212

13+
#include "utils.cuh"
14+
1315
// -----------------------------------------------------------------------------
1416
template <typename TA, typename TB, typename TC>
1517
void test_gemm_device_work( Params& params, bool run )
@@ -197,11 +199,225 @@ void test_gemm_device_work( Params& params, bool run )
197199
blas::device_free( dB, queue );
198200
blas::device_free( dC, queue );
199201
}
202+
//
203+
// -----------------------------------------------------------------------------
204+
template <>
205+
void test_gemm_device_work<half,half,half>( Params& params, bool run )
206+
{
207+
using namespace testsweeper;
208+
using std::real;
209+
using std::imag;
210+
using blas::Op;
211+
using blas::Layout;
212+
using scalar_hi = float;
213+
using scalar_lo = half;
214+
using real_t = blas::real_type< scalar_hi >;
215+
216+
// get & mark input values
217+
blas::Layout layout = params.layout();
218+
blas::Op transA = params.transA();
219+
blas::Op transB = params.transB();
220+
scalar_lo alpha = params.alpha();
221+
scalar_lo beta = params.beta();
222+
int64_t m = params.dim.m();
223+
int64_t n = params.dim.n();
224+
int64_t k = params.dim.k();
225+
int64_t device = params.device();
226+
int64_t align = params.align();
227+
int64_t verbose = params.verbose();
228+
229+
// mark non-standard output values
230+
params.gflops();
231+
params.ref_time();
232+
params.ref_gflops();
233+
234+
if (! run)
235+
return;
236+
237+
if (blas::get_device_count() == 0) {
238+
params.msg() = "skipping: no GPU devices or no GPU support";
239+
return;
240+
}
241+
242+
// setup
243+
int64_t Am = (transA == Op::NoTrans ? m : k);
244+
int64_t An = (transA == Op::NoTrans ? k : m);
245+
int64_t Bm = (transB == Op::NoTrans ? k : n);
246+
int64_t Bn = (transB == Op::NoTrans ? n : k);
247+
int64_t Cm = m;
248+
int64_t Cn = n;
249+
if (layout == Layout::RowMajor) {
250+
std::swap( Am, An );
251+
std::swap( Bm, Bn );
252+
std::swap( Cm, Cn );
253+
}
254+
int64_t lda = roundup( Am, align );
255+
int64_t ldb = roundup( Bm, align );
256+
int64_t ldc = roundup( Cm, align );
257+
size_t size_A = size_t(lda)*An;
258+
size_t size_B = size_t(ldb)*Bn;
259+
size_t size_C = size_t(ldc)*Cn;
260+
half* A_lo = new half[ size_A ];
261+
half* B_lo = new half[ size_B ];
262+
half* C_lo = new half[ size_C ];
263+
float* A_hi = new float[ size_A ];
264+
float* B_hi = new float[ size_B ];
265+
float* C_hi = new float[ size_C ];
266+
float* Cref = new float[ size_C ];
267+
268+
// device specifics
269+
blas::Queue queue( device );
270+
half* dA_lo;
271+
half* dB_lo;
272+
half* dC_lo;
273+
float* dA_hi;
274+
float* dB_hi;
275+
float* dC_hi;
276+
277+
dA_lo = blas::device_malloc<half>( size_A, queue );
278+
dB_lo = blas::device_malloc<half>( size_B, queue );
279+
dC_lo = blas::device_malloc<half>( size_C, queue );
280+
dA_hi = blas::device_malloc<float>( size_A, queue );
281+
dB_hi = blas::device_malloc<float>( size_B, queue );
282+
dC_hi = blas::device_malloc<float>( size_C, queue );
283+
284+
int64_t idist = 1;
285+
int iseed[4] = { 0, 0, 0, 1 };
286+
lapack_larnv( idist, iseed, size_A, A_hi );
287+
lapack_larnv( idist, iseed, size_B, B_hi );
288+
lapack_larnv( idist, iseed, size_C, C_hi );
289+
lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc );
290+
291+
blas::device_copy_matrix(Am, An, A_hi, lda, dA_hi, lda, queue);
292+
blas::device_copy_matrix(Bm, Bn, B_hi, ldb, dB_hi, ldb, queue);
293+
blas::device_copy_matrix(Cm, Cn, C_hi, ldc, dC_hi, ldc, queue);
294+
295+
blas::copy_matrix( Am, An, dA_hi, lda, dA_lo, lda, queue );
296+
blas::copy_matrix( Bm, Bn, dB_hi, ldb, dB_lo, ldb, queue );
297+
blas::copy_matrix( Cm, Cn, dC_hi, ldc, dC_lo, ldc, queue );
298+
queue.sync();
299+
300+
// norms for error check
301+
real_t work[1];
302+
real_t Anorm = lapack_lange( "f", Am, An, A_hi, lda, work );
303+
real_t Bnorm = lapack_lange( "f", Bm, Bn, B_hi, ldb, work );
304+
real_t Cnorm = lapack_lange( "f", Cm, Cn, C_hi, ldc, work );
305+
306+
// test error exits
307+
assert_throw( blas::gemm( Layout(0), transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
308+
assert_throw( blas::gemm( layout, Op(0), transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
309+
assert_throw( blas::gemm( layout, transA, Op(0), m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
310+
assert_throw( blas::gemm( layout, transA, transB, -1, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
311+
assert_throw( blas::gemm( layout, transA, transB, m, -1, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
312+
assert_throw( blas::gemm( layout, transA, transB, m, n, -1, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
313+
314+
assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
315+
assert_throw( blas::gemm( Layout::ColMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
316+
assert_throw( blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
317+
318+
assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
319+
assert_throw( blas::gemm( Layout::RowMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
320+
assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error );
321+
322+
assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
323+
assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
324+
assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
325+
326+
assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error );
327+
assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
328+
assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error );
329+
330+
assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, m-1, queue ), blas::Error );
331+
assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, n-1, queue ), blas::Error );
332+
333+
if (verbose >= 1) {
334+
printf( "\n"
335+
"A Am=%5lld, An=%5lld, lda=%5lld, size=%10lld, norm %.2e\n"
336+
"B Bm=%5lld, Bn=%5lld, ldb=%5lld, size=%10lld, norm %.2e\n"
337+
"C Cm=%5lld, Cn=%5lld, ldc=%5lld, size=%10lld, norm %.2e\n",
338+
llong( Am ), llong( An ), llong( lda ), llong( size_A ), Anorm,
339+
llong( Bm ), llong( Bn ), llong( ldb ), llong( size_B ), Bnorm,
340+
llong( Cm ), llong( Cn ), llong( ldc ), llong( size_C ), Cnorm );
341+
}
342+
if (verbose >= 2) {
343+
printf( "alpha = %.4e + %.4ei; beta = %.4e + %.4ei;\n",
344+
blas::real(alpha), blas::imag(alpha),
345+
blas::real(beta), blas::imag(beta) );
346+
printf( "A = " ); print_matrix( Am, An, A_hi, lda );
347+
printf( "B = " ); print_matrix( Bm, Bn, B_hi, ldb );
348+
printf( "C = " ); print_matrix( Cm, Cn, C_hi, ldc );
349+
}
350+
351+
// run test
352+
testsweeper::flush_cache( params.cache() );
353+
double time = get_wtime();
354+
blas::gemm( layout, transA, transB, m, n, k,
355+
alpha, dA_lo, lda, dB_lo, ldb, beta, dC_lo, ldc, queue );
356+
queue.sync();
357+
time = get_wtime() - time;
358+
359+
double gflop = blas::Gflop< scalar_hi >::gemm( m, n, k );
360+
params.time() = time;
361+
params.gflops() = gflop / time;
362+
363+
blas::copy_matrix( Cm, Cn, dC_lo, ldc, dC_hi, ldc, queue );
364+
blas::device_copy_matrix(Cm, Cn, dC_hi, ldc, C_hi, ldc, queue);
365+
queue.sync();
366+
367+
if (verbose >= 2) {
368+
printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc );
369+
}
370+
371+
if (params.ref() == 'y' || params.check() == 'y') {
372+
// run reference
373+
testsweeper::flush_cache( params.cache() );
374+
time = get_wtime();
375+
cblas_gemm( cblas_layout_const(layout),
376+
cblas_trans_const(transA),
377+
cblas_trans_const(transB),
378+
m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, Cref, ldc ); // keep it like this as it defines the reference
379+
time = get_wtime() - time;
380+
381+
params.ref_time() = time;
382+
params.ref_gflops() = gflop / time;
383+
384+
if (verbose >= 2) {
385+
printf( "Cref = " ); print_matrix( Cm, Cn, Cref, ldc );
386+
}
387+
388+
// check error compared to reference
389+
real_t error;
390+
bool okay;
391+
check_gemm<float, half>( Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm,
392+
Cref, ldc, C_hi, ldc, verbose, &error, &okay );
393+
params.error() = error;
394+
params.okay() = okay;
395+
}
396+
397+
delete[] A_hi;
398+
delete[] B_hi;
399+
delete[] C_hi;
400+
delete[] A_lo;
401+
delete[] B_lo;
402+
delete[] C_lo;
403+
delete[] Cref;
404+
405+
blas::device_free( dA_hi, queue );
406+
blas::device_free( dB_hi, queue );
407+
blas::device_free( dC_hi, queue );
408+
blas::device_free( dA_lo, queue );
409+
blas::device_free( dB_lo, queue );
410+
blas::device_free( dC_lo, queue );
411+
}
200412

201413
// -----------------------------------------------------------------------------
202414
void test_gemm_device( Params& params, bool run )
203415
{
204416
switch (params.datatype()) {
417+
case testsweeper::DataType::Half:
418+
test_gemm_device_work< half, half, half >( params, run );
419+
break;
420+
205421
case testsweeper::DataType::Single:
206422
test_gemm_device_work< float, float, float >( params, run );
207423
break;

0 commit comments

Comments
 (0)