|
33 | 33 | #include "testing_gemm_strided_batched_kernel_name.hpp" |
34 | 34 | #include "testing_trsm.hpp" |
35 | 35 | #include "testing_gemm_ex.hpp" |
| 36 | +#include "testing_gemm_strided_batched_ex.hpp" |
36 | 37 | #endif |
37 | 38 |
|
38 | 39 | namespace po = boost::program_options; |
@@ -114,7 +115,7 @@ int main(int argc, char* argv[]) |
114 | 115 | "BLAS-2 and BLAS-3: second dimension * leading dimension.") |
115 | 116 |
|
116 | 117 | ("stride_d", |
117 | | - po::value<rocblas_int>(&argus.stride_c)->default_value(128*128), |
| 118 | + po::value<rocblas_int>(&argus.stride_d)->default_value(128*128), |
118 | 119 | "Specific stride of strided_batched matrix D, is only applicable to strided batched" |
119 | 120 | "BLAS_EX: second dimension * leading dimension.") |
120 | 121 |
|
@@ -500,6 +501,38 @@ int main(int argc, char* argv[]) |
500 | 501 | else if(precision == 'd') |
501 | 502 | testing_gemm_strided_batched<double>(argus); |
502 | 503 | } |
| 504 | + else if(function == "gemm_strided_batched_ex") |
| 505 | + { |
| 506 | + // adjust dimension for GEMM routines |
| 507 | + rocblas_int min_lda = argus.transA_option == 'N' ? argus.M : argus.K; |
| 508 | + rocblas_int min_ldb = argus.transB_option == 'N' ? argus.K : argus.N; |
| 509 | + rocblas_int min_ldc = argus.M; |
| 510 | + if(argus.lda < min_lda) |
| 511 | + { |
| 512 | + std::cout << "rocblas-bench INFO: lda < min_lda, set lda = " << min_lda << std::endl; |
| 513 | + argus.lda = min_lda; |
| 514 | + } |
| 515 | + if(argus.ldb < min_ldb) |
| 516 | + { |
| 517 | + std::cout << "rocblas-bench INFO: ldb < min_ldb, set ldb = " << min_ldb << std::endl; |
| 518 | + argus.ldb = min_ldb; |
| 519 | + } |
| 520 | + if(argus.ldc < min_ldc) |
| 521 | + { |
| 522 | + std::cout << "rocblas-bench INFO: ldc < min_ldc, set ldc = " << min_ldc << std::endl; |
| 523 | + argus.ldc = min_ldc; |
| 524 | + } |
| 525 | + |
| 526 | + rocblas_int min_stride_c = argus.ldc * argus.N; |
| 527 | + if(argus.stride_c < min_stride_c) |
| 528 | + { |
| 529 | + std::cout << "rocblas-bench INFO: stride_c < min_stride_c, set stride_c = " |
| 530 | + << min_stride_c << std::endl; |
| 531 | + argus.stride_c = min_stride_c; |
| 532 | + } |
| 533 | + |
| 534 | + testing_gemm_strided_batched_ex(argus); |
| 535 | + } |
503 | 536 | else if(function == "gemm_kernel_name") |
504 | 537 | { |
505 | 538 | // adjust dimension for GEMM routines |
|
0 commit comments