Skip to content

Commit ffe77fa

Browse files
authored
Merge pull request #34 from amcamd/add-geam
adding geam code and tests
2 parents 7e1983e + 69c6376 commit ffe77fa

File tree

9 files changed

+436
-15
lines changed

9 files changed

+436
-15
lines changed

clients/common/hipblas_template_specialization.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,5 +526,30 @@
526526
}
527527
*/
528528

529+
template<>
530+
hipblasStatus_t hipblasGeam<float>(hipblasHandle_t handle,
531+
hipblasOperation_t transA, hipblasOperation_t transB,
532+
int m, int n,
533+
const float *alpha,
534+
const float *A, int lda,
535+
const float *beta,
536+
const float *B, int ldb,
537+
float *C, int ldc){
538+
return hipblasSgeam(handle, transA, transB, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
539+
}
540+
541+
template<>
542+
hipblasStatus_t hipblasGeam<double>(hipblasHandle_t handle,
543+
hipblasOperation_t transA, hipblasOperation_t transB,
544+
int m, int n,
545+
const double *alpha,
546+
const double *A, int lda,
547+
const double *beta,
548+
const double *B, int ldb,
549+
double *C, int ldc){
550+
return hipblasDgeam(handle, transA, transB, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
551+
}
552+
553+
529554

530555

clients/gtest/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ set(hipblas_test_source
4646
ger_gtest.cpp
4747
gemm_gtest.cpp
4848
gemm_strided_batched_gtest.cpp
49+
geam_gtest.cpp
4950
)
5051

5152
set( hipblas_benchmark_common

clients/gtest/geam_gtest.cpp

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/* ************************************************************************
2+
* Copyright 2016 Advanced Micro Devices, Inc.
3+
* ************************************************************************ */
4+
5+
#include <gtest/gtest.h>
6+
#include <math.h>
7+
#include <stdexcept>
8+
#include <vector>
9+
#include "testing_geam.hpp"
10+
#include "utility.h"
11+
12+
using ::testing::TestWithParam;
13+
using ::testing::Values;
14+
using ::testing::ValuesIn;
15+
using ::testing::Combine;
16+
using namespace std;
17+
18+
//only GCC/VS 2010 comes with std::tr1::tuple, but it is unnecessary, std::tuple is good enough;
19+
20+
typedef std::tuple<vector<int>, vector<double>, vector<char>> geam_tuple;
21+
22+
/* =====================================================================
23+
README: This file contains testers to verify the correctness of
24+
BLAS routines with google test
25+
26+
It is supposed to be played/used by advance / expert users
27+
Normal users only need to get the library routines without testers
28+
=================================================================== */
29+
30+
//vector of vector, each vector is a {M, N, lda, ldb, ldc};
31+
//add/delete as a group
32+
const
33+
vector<vector<int>> matrix_size_range = {
34+
// { -1, -1, -1, 1, 1},
35+
{ 5, 5, 5, 5, 5},
36+
{ 3, 33, 33, 34, 35},
37+
{ 10, 10, 100, 10, 10},
38+
// { 600, 500, 500, 600, 500},
39+
// {1024, 1024, 1024, 1024, 1024}
40+
};
41+
42+
//vector of vector, each pair is a {alpha, beta};
43+
//add/delete this list in pairs, like {2.0, 4.0}
44+
const
45+
vector<vector<double>> alpha_beta_range = { {1.4, 0.0},
46+
{3.1, 0.3},
47+
{0.0, 1.3},
48+
{0.0, 0.0},
49+
};
50+
51+
52+
//vector of vector, each pair is a {transA, transB};
53+
//add/delete this list in pairs, like {'N', 'T'}
54+
//for single/double precision, 'C'(conjTranspose) will downgraded to 'T' (transpose) internally in sgeam/dgeam,
55+
const
56+
vector<vector<char>> transA_transB_range = {
57+
{'N', 'N'},
58+
{'N', 'T'},
59+
{'C', 'N'},
60+
{'T', 'C'}
61+
};
62+
63+
64+
/* ===============Google Unit Test==================================================== */
65+
66+
67+
/* =====================================================================
68+
BLAS-3 GEAM:
69+
=================================================================== */
70+
71+
/* ============================Setup Arguments======================================= */
72+
73+
//Please use "class Arguments" (see utility.hpp) to pass parameters to templated testers;
74+
//Some routines may not touch/use certain "members" of objects "argus".
75+
//like BLAS-1 Scal does not have lda, BLAS-2 GEMV does not have ldb, ldc;
76+
//That is fine. These testers & routines will leave untouched members alone.
77+
//Do not use std::tuple to directly pass parameters to testers
78+
//by std:tuple, you have unpack it with extreme care for each one by like "std::get<0>" which is not intuitive and error-prone
79+
80+
Arguments setup_geam_arguments(geam_tuple tup)
81+
{
82+
83+
vector<int> matrix_size = std::get<0>(tup);
84+
vector<double> alpha_beta = std::get<1>(tup);
85+
vector<char> transA_transB = std::get<2>(tup);
86+
87+
Arguments arg;
88+
89+
// see the comments about matrix_size_range above
90+
arg.M = matrix_size[0];
91+
arg.N = matrix_size[1];
92+
arg.lda = matrix_size[3];
93+
arg.ldb = matrix_size[4];
94+
arg.ldc = matrix_size[5];
95+
96+
//the first element of alpha_beta_range is always alpha, and the second is always beta
97+
arg.alpha = alpha_beta[0];
98+
arg.beta = alpha_beta[1];
99+
100+
arg.transA_option = transA_transB[0];
101+
arg.transB_option = transA_transB[1];
102+
103+
arg.timing = 0;
104+
105+
return arg;
106+
}
107+
108+
109+
class geam_gtest: public :: TestWithParam <geam_tuple>
110+
{
111+
protected:
112+
geam_gtest(){}
113+
virtual ~geam_gtest(){}
114+
virtual void SetUp(){}
115+
virtual void TearDown(){}
116+
};
117+
118+
119+
TEST_P(geam_gtest, geam_gtest_float)
120+
{
121+
// GetParam return a tuple. Tee setup routine unpack the tuple
122+
// and initializes arg(Arguments) which will be passed to testing routine
123+
// The Arguments data struture have physical meaning associated.
124+
// while the tuple is non-intuitive.
125+
126+
127+
Arguments arg = setup_geam_arguments( GetParam() );
128+
129+
hipblasStatus_t status = testing_geam<float>( arg );
130+
131+
// if not success, then the input argument is problematic, so detect the error message
132+
if(status != HIPBLAS_STATUS_SUCCESS)
133+
{
134+
if( arg.M < 0 || arg.N < 0 )
135+
{
136+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
137+
}
138+
else if(arg.transA_option == 'N' ? arg.lda < arg.M : arg.lda < arg.K)
139+
{
140+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
141+
}
142+
else if(arg.transB_option == 'N' ? arg.ldb < arg.K : arg.ldb < arg.N)
143+
{
144+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
145+
}
146+
else if(arg.ldc < arg.M)
147+
{
148+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
149+
}
150+
}
151+
}
152+
153+
154+
TEST_P(geam_gtest, geam_gtest_double)
155+
{
156+
// GetParam return a tuple. Tee setup routine unpack the tuple
157+
// and initializes arg(Arguments) which will be passed to testing routine
158+
// The Arguments data struture have physical meaning associated.
159+
// while the tuple is non-intuitive.
160+
161+
162+
Arguments arg = setup_geam_arguments( GetParam() );
163+
164+
hipblasStatus_t status = testing_geam<double>( arg );
165+
166+
// if not success, then the input argument is problematic, so detect the error message
167+
if(status != HIPBLAS_STATUS_SUCCESS)
168+
{
169+
if (arg.M < 0 || arg.N < 0)
170+
{
171+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
172+
}
173+
else if (arg.transA_option == 'N' ? arg.lda < arg.M : arg.lda < arg.K)
174+
{
175+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
176+
}
177+
else if (arg.transB_option == 'N' ? arg.ldb < arg.K : arg.ldb < arg.N)
178+
{
179+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
180+
}
181+
else if (arg.ldc < arg.M)
182+
{
183+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
184+
}
185+
}
186+
}
187+
188+
//THis function mainly test the scope of alpha_beta, transA_transB,.the scope of matrix_size_range is small
189+
190+
INSTANTIATE_TEST_CASE_P(hipblasGeam_scalar_transpose, geam_gtest,
191+
Combine(
192+
ValuesIn(matrix_size_range),
193+
ValuesIn(alpha_beta_range),
194+
ValuesIn(transA_transB_range)
195+
)
196+
);

clients/include/hipblas.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,5 +182,16 @@
182182
T *invA);
183183

184184

185+
template<typename T>
186+
hipblasStatus_t hipblasGeam(hipblasHandle_t handle,
187+
hipblasOperation_t transA, hipblasOperation_t transB,
188+
int m, int n,
189+
const T *alpha,
190+
const T *A, int lda,
191+
const T *beta,
192+
const T *B, int ldb,
193+
T *C, int ldc);
194+
195+
185196

186197
#endif // _ROCBLAS_HPP_

0 commit comments

Comments
 (0)