|
10 | 10 | #include "print_matrix.hh"
|
11 | 11 | #include "check_gemm.hh"
|
12 | 12 |
|
| 13 | +#include "utils.cuh" |
| 14 | + |
13 | 15 | // -----------------------------------------------------------------------------
|
14 | 16 | template <typename TA, typename TB, typename TC>
|
15 | 17 | void test_gemm_device_work( Params& params, bool run )
|
@@ -197,11 +199,225 @@ void test_gemm_device_work( Params& params, bool run )
|
197 | 199 | blas::device_free( dB, queue );
|
198 | 200 | blas::device_free( dC, queue );
|
199 | 201 | }
|
| 202 | +// |
| 203 | +// ----------------------------------------------------------------------------- |
| 204 | +template <> |
| 205 | +void test_gemm_device_work<half,half,half>( Params& params, bool run ) |
| 206 | +{ |
| 207 | + using namespace testsweeper; |
| 208 | + using std::real; |
| 209 | + using std::imag; |
| 210 | + using blas::Op; |
| 211 | + using blas::Layout; |
| 212 | + using scalar_hi = float; |
| 213 | + using scalar_lo = half; |
| 214 | + using real_t = blas::real_type< scalar_hi >; |
| 215 | + |
| 216 | + // get & mark input values |
| 217 | + blas::Layout layout = params.layout(); |
| 218 | + blas::Op transA = params.transA(); |
| 219 | + blas::Op transB = params.transB(); |
| 220 | + scalar_lo alpha = params.alpha(); |
| 221 | + scalar_lo beta = params.beta(); |
| 222 | + int64_t m = params.dim.m(); |
| 223 | + int64_t n = params.dim.n(); |
| 224 | + int64_t k = params.dim.k(); |
| 225 | + int64_t device = params.device(); |
| 226 | + int64_t align = params.align(); |
| 227 | + int64_t verbose = params.verbose(); |
| 228 | + |
| 229 | + // mark non-standard output values |
| 230 | + params.gflops(); |
| 231 | + params.ref_time(); |
| 232 | + params.ref_gflops(); |
| 233 | + |
| 234 | + if (! run) |
| 235 | + return; |
| 236 | + |
| 237 | + if (blas::get_device_count() == 0) { |
| 238 | + params.msg() = "skipping: no GPU devices or no GPU support"; |
| 239 | + return; |
| 240 | + } |
| 241 | + |
| 242 | + // setup |
| 243 | + int64_t Am = (transA == Op::NoTrans ? m : k); |
| 244 | + int64_t An = (transA == Op::NoTrans ? k : m); |
| 245 | + int64_t Bm = (transB == Op::NoTrans ? k : n); |
| 246 | + int64_t Bn = (transB == Op::NoTrans ? n : k); |
| 247 | + int64_t Cm = m; |
| 248 | + int64_t Cn = n; |
| 249 | + if (layout == Layout::RowMajor) { |
| 250 | + std::swap( Am, An ); |
| 251 | + std::swap( Bm, Bn ); |
| 252 | + std::swap( Cm, Cn ); |
| 253 | + } |
| 254 | + int64_t lda = roundup( Am, align ); |
| 255 | + int64_t ldb = roundup( Bm, align ); |
| 256 | + int64_t ldc = roundup( Cm, align ); |
| 257 | + size_t size_A = size_t(lda)*An; |
| 258 | + size_t size_B = size_t(ldb)*Bn; |
| 259 | + size_t size_C = size_t(ldc)*Cn; |
| 260 | + half* A_lo = new half[ size_A ]; |
| 261 | + half* B_lo = new half[ size_B ]; |
| 262 | + half* C_lo = new half[ size_C ]; |
| 263 | + float* A_hi = new float[ size_A ]; |
| 264 | + float* B_hi = new float[ size_B ]; |
| 265 | + float* C_hi = new float[ size_C ]; |
| 266 | + float* Cref = new float[ size_C ]; |
| 267 | + |
| 268 | + // device specifics |
| 269 | + blas::Queue queue( device ); |
| 270 | + half* dA_lo; |
| 271 | + half* dB_lo; |
| 272 | + half* dC_lo; |
| 273 | + float* dA_hi; |
| 274 | + float* dB_hi; |
| 275 | + float* dC_hi; |
| 276 | + |
| 277 | + dA_lo = blas::device_malloc<half>( size_A, queue ); |
| 278 | + dB_lo = blas::device_malloc<half>( size_B, queue ); |
| 279 | + dC_lo = blas::device_malloc<half>( size_C, queue ); |
| 280 | + dA_hi = blas::device_malloc<float>( size_A, queue ); |
| 281 | + dB_hi = blas::device_malloc<float>( size_B, queue ); |
| 282 | + dC_hi = blas::device_malloc<float>( size_C, queue ); |
| 283 | + |
| 284 | + int64_t idist = 1; |
| 285 | + int iseed[4] = { 0, 0, 0, 1 }; |
| 286 | + lapack_larnv( idist, iseed, size_A, A_hi ); |
| 287 | + lapack_larnv( idist, iseed, size_B, B_hi ); |
| 288 | + lapack_larnv( idist, iseed, size_C, C_hi ); |
| 289 | + lapack_lacpy( "g", Cm, Cn, C_hi, ldc, Cref, ldc ); |
| 290 | + |
| 291 | + blas::device_copy_matrix(Am, An, A_hi, lda, dA_hi, lda, queue); |
| 292 | + blas::device_copy_matrix(Bm, Bn, B_hi, ldb, dB_hi, ldb, queue); |
| 293 | + blas::device_copy_matrix(Cm, Cn, C_hi, ldc, dC_hi, ldc, queue); |
| 294 | + |
| 295 | + blas::copy_matrix( Am, An, dA_hi, lda, dA_lo, lda, queue ); |
| 296 | + blas::copy_matrix( Bm, Bn, dB_hi, ldb, dB_lo, ldb, queue ); |
| 297 | + blas::copy_matrix( Cm, Cn, dC_hi, ldc, dC_lo, ldc, queue ); |
| 298 | + queue.sync(); |
| 299 | + |
| 300 | + // norms for error check |
| 301 | + real_t work[1]; |
| 302 | + real_t Anorm = lapack_lange( "f", Am, An, A_hi, lda, work ); |
| 303 | + real_t Bnorm = lapack_lange( "f", Bm, Bn, B_hi, ldb, work ); |
| 304 | + real_t Cnorm = lapack_lange( "f", Cm, Cn, C_hi, ldc, work ); |
| 305 | + |
| 306 | + // test error exits |
| 307 | + assert_throw( blas::gemm( Layout(0), transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 308 | + assert_throw( blas::gemm( layout, Op(0), transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 309 | + assert_throw( blas::gemm( layout, transA, Op(0), m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 310 | + assert_throw( blas::gemm( layout, transA, transB, -1, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 311 | + assert_throw( blas::gemm( layout, transA, transB, m, -1, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 312 | + assert_throw( blas::gemm( layout, transA, transB, m, n, -1, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 313 | + |
| 314 | + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 315 | + assert_throw( blas::gemm( Layout::ColMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 316 | + assert_throw( blas::gemm( Layout::ColMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 317 | + |
| 318 | + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, k-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 319 | + assert_throw( blas::gemm( Layout::RowMajor, Op::Trans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 320 | + assert_throw( blas::gemm( Layout::RowMajor, Op::ConjTrans, Op::NoTrans, m, n, k, alpha, dA_hi, m-1, dB_hi, ldb, beta, dC_hi, ldc, queue ), blas::Error ); |
| 321 | + |
| 322 | + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error ); |
| 323 | + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error ); |
| 324 | + assert_throw( blas::gemm( Layout::ColMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error ); |
| 325 | + |
| 326 | + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::NoTrans, m, n, k, alpha, dA_hi, lda, B_hi, n-1, beta, dC_hi, ldc, queue ), blas::Error ); |
| 327 | + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::Trans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error ); |
| 328 | + assert_throw( blas::gemm( Layout::RowMajor, Op::NoTrans, Op::ConjTrans, m, n, k, alpha, dA_hi, lda, B_hi, k-1, beta, dC_hi, ldc, queue ), blas::Error ); |
| 329 | + |
| 330 | + assert_throw( blas::gemm( Layout::ColMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, m-1, queue ), blas::Error ); |
| 331 | + assert_throw( blas::gemm( Layout::RowMajor, transA, transB, m, n, k, alpha, dA_hi, lda, dB_hi, ldb, beta, dC_hi, n-1, queue ), blas::Error ); |
| 332 | + |
| 333 | + if (verbose >= 1) { |
| 334 | + printf( "\n" |
| 335 | + "A Am=%5lld, An=%5lld, lda=%5lld, size=%10lld, norm %.2e\n" |
| 336 | + "B Bm=%5lld, Bn=%5lld, ldb=%5lld, size=%10lld, norm %.2e\n" |
| 337 | + "C Cm=%5lld, Cn=%5lld, ldc=%5lld, size=%10lld, norm %.2e\n", |
| 338 | + llong( Am ), llong( An ), llong( lda ), llong( size_A ), Anorm, |
| 339 | + llong( Bm ), llong( Bn ), llong( ldb ), llong( size_B ), Bnorm, |
| 340 | + llong( Cm ), llong( Cn ), llong( ldc ), llong( size_C ), Cnorm ); |
| 341 | + } |
| 342 | + if (verbose >= 2) { |
| 343 | + printf( "alpha = %.4e + %.4ei; beta = %.4e + %.4ei;\n", |
| 344 | + blas::real(alpha), blas::imag(alpha), |
| 345 | + blas::real(beta), blas::imag(beta) ); |
| 346 | + printf( "A = " ); print_matrix( Am, An, A_hi, lda ); |
| 347 | + printf( "B = " ); print_matrix( Bm, Bn, B_hi, ldb ); |
| 348 | + printf( "C = " ); print_matrix( Cm, Cn, C_hi, ldc ); |
| 349 | + } |
| 350 | + |
| 351 | + // run test |
| 352 | + testsweeper::flush_cache( params.cache() ); |
| 353 | + double time = get_wtime(); |
| 354 | + blas::gemm( layout, transA, transB, m, n, k, |
| 355 | + alpha, dA_lo, lda, dB_lo, ldb, beta, dC_lo, ldc, queue ); |
| 356 | + queue.sync(); |
| 357 | + time = get_wtime() - time; |
| 358 | + |
| 359 | + double gflop = blas::Gflop< scalar_hi >::gemm( m, n, k ); |
| 360 | + params.time() = time; |
| 361 | + params.gflops() = gflop / time; |
| 362 | + |
| 363 | + blas::copy_matrix( Cm, Cn, dC_lo, ldc, dC_hi, ldc, queue ); |
| 364 | + blas::device_copy_matrix(Cm, Cn, dC_hi, ldc, C_hi, ldc, queue); |
| 365 | + queue.sync(); |
| 366 | + |
| 367 | + if (verbose >= 2) { |
| 368 | + printf( "C2 = " ); print_matrix( Cm, Cn, C_hi, ldc ); |
| 369 | + } |
| 370 | + |
| 371 | + if (params.ref() == 'y' || params.check() == 'y') { |
| 372 | + // run reference |
| 373 | + testsweeper::flush_cache( params.cache() ); |
| 374 | + time = get_wtime(); |
| 375 | + cblas_gemm( cblas_layout_const(layout), |
| 376 | + cblas_trans_const(transA), |
| 377 | + cblas_trans_const(transB), |
| 378 | + m, n, k, alpha, A_hi, lda, B_hi, ldb, beta, Cref, ldc ); // keep it like this as it defines the reference |
| 379 | + time = get_wtime() - time; |
| 380 | + |
| 381 | + params.ref_time() = time; |
| 382 | + params.ref_gflops() = gflop / time; |
| 383 | + |
| 384 | + if (verbose >= 2) { |
| 385 | + printf( "Cref = " ); print_matrix( Cm, Cn, Cref, ldc ); |
| 386 | + } |
| 387 | + |
| 388 | + // check error compared to reference |
| 389 | + real_t error; |
| 390 | + bool okay; |
| 391 | + check_gemm<float, half>( Cm, Cn, k, alpha, beta, Anorm, Bnorm, Cnorm, |
| 392 | + Cref, ldc, C_hi, ldc, verbose, &error, &okay ); |
| 393 | + params.error() = error; |
| 394 | + params.okay() = okay; |
| 395 | + } |
| 396 | + |
| 397 | + delete[] A_hi; |
| 398 | + delete[] B_hi; |
| 399 | + delete[] C_hi; |
| 400 | + delete[] A_lo; |
| 401 | + delete[] B_lo; |
| 402 | + delete[] C_lo; |
| 403 | + delete[] Cref; |
| 404 | + |
| 405 | + blas::device_free( dA_hi, queue ); |
| 406 | + blas::device_free( dB_hi, queue ); |
| 407 | + blas::device_free( dC_hi, queue ); |
| 408 | + blas::device_free( dA_lo, queue ); |
| 409 | + blas::device_free( dB_lo, queue ); |
| 410 | + blas::device_free( dC_lo, queue ); |
| 411 | +} |
200 | 412 |
|
201 | 413 | // -----------------------------------------------------------------------------
|
202 | 414 | void test_gemm_device( Params& params, bool run )
|
203 | 415 | {
|
204 | 416 | switch (params.datatype()) {
|
| 417 | + case testsweeper::DataType::Half: |
| 418 | + test_gemm_device_work< half, half, half >( params, run ); |
| 419 | + break; |
| 420 | + |
205 | 421 | case testsweeper::DataType::Single:
|
206 | 422 | test_gemm_device_work< float, float, float >( params, run );
|
207 | 423 | break;
|
|
0 commit comments