Skip to content

Commit 50b865f

Browse files
authored
Hotfix: dgmm test fix (#247)
1 parent 5299dc5 commit 50b865f

14 files changed

+69
-57
lines changed

clients/gtest/dgmm_gtest.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,22 @@ TEST_P(dgmm_gtest, dgmm_gtest_float)
122122
// The Arguments data struture have physical meaning associated.
123123
// while the tuple is non-intuitive.
124124

125-
// Arguments arg = setup_dgmm_arguments(GetParam());
126-
127-
// hipblasStatus_t status = testing_dgmm<float>(arg);
128-
129-
// // if not success, then the input argument is problematic, so detect the error message
130-
// if(status != HIPBLAS_STATUS_SUCCESS)
131-
// {
132-
// if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M || arg.incx == 0)
133-
// {
134-
// EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
135-
// }
136-
// else
137-
// {
138-
// EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
139-
// }
140-
// }
125+
Arguments arg = setup_dgmm_arguments(GetParam());
126+
127+
hipblasStatus_t status = testing_dgmm<float>(arg);
128+
129+
// if not success, then the input argument is problematic, so detect the error message
130+
if(status != HIPBLAS_STATUS_SUCCESS)
131+
{
132+
if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M)
133+
{
134+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
135+
}
136+
else
137+
{
138+
EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
139+
}
140+
}
141141
}
142142

143143
TEST_P(dgmm_gtest, dgmm_gtest_float_complex)
@@ -147,22 +147,22 @@ TEST_P(dgmm_gtest, dgmm_gtest_float_complex)
147147
// The Arguments data struture have physical meaning associated.
148148
// while the tuple is non-intuitive.
149149

150-
// Arguments arg = setup_dgmm_arguments(GetParam());
151-
152-
// hipblasStatus_t status = testing_dgmm<hipblasComplex>(arg);
153-
154-
// // if not success, then the input argument is problematic, so detect the error message
155-
// if(status != HIPBLAS_STATUS_SUCCESS)
156-
// {
157-
// if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M || arg.incx == 0)
158-
// {
159-
// EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
160-
// }
161-
// else
162-
// {
163-
// EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
164-
// }
165-
// }
150+
Arguments arg = setup_dgmm_arguments(GetParam());
151+
152+
hipblasStatus_t status = testing_dgmm<hipblasComplex>(arg);
153+
154+
// if not success, then the input argument is problematic, so detect the error message
155+
if(status != HIPBLAS_STATUS_SUCCESS)
156+
{
157+
if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M)
158+
{
159+
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
160+
}
161+
else
162+
{
163+
EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
164+
}
165+
}
166166
}
167167

168168
TEST_P(dgmm_gtest, dgmm_batched_gtest_float)

clients/include/hipblas_vector.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class device_batch_vector : private d_vector<T, PAD, U>
131131
return data[n];
132132
}
133133

134-
operator T**()
134+
operator T* *()
135135
{
136136
return data;
137137
}

clients/include/testing_dgmm.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ hipblasStatus_t testing_dgmm(Arguments argus)
3737
int C_size = size_t(ldc) * N;
3838
int k = (side == HIPBLAS_SIDE_RIGHT ? N : M);
3939
int X_size = size_t(incx) * k;
40+
if(!X_size)
41+
X_size = 1;
4042

4143
hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS;
4244

4345
// argument sanity check, quick return if input parameters are invalid before allocating invalid
4446
// memory
45-
if(M < 0 || N < 0 || lda < M || ldc < M || incx == 0)
47+
if(M < 0 || N < 0 || lda < M || ldc < M)
4648
{
4749
status = HIPBLAS_STATUS_INVALID_VALUE;
4850
return status;
@@ -110,11 +112,11 @@ hipblasStatus_t testing_dgmm(Arguments argus)
110112
{
111113
if(HIPBLAS_SIDE_RIGHT == side)
112114
{
113-
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] + hx_copy[i2 * incx];
115+
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] * hx_copy[i2 * incx];
114116
}
115117
else
116118
{
117-
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] + hx_copy[i1 * incx];
119+
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] * hx_copy[i1 * incx];
118120
}
119121
}
120122
}

clients/include/testing_dgmm_batched.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ hipblasStatus_t testing_dgmm_batched(Arguments argus)
3939
int C_size = size_t(ldc) * N;
4040
int k = (side == HIPBLAS_SIDE_RIGHT ? N : M);
4141
int X_size = size_t(incx) * k;
42+
if(!X_size)
43+
X_size = 1;
4244

4345
hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS;
4446

4547
// argument sanity check, quick return if input parameters are invalid before allocating invalid
4648
// memory
47-
if(M < 0 || N < 0 || lda < M || ldc < M || incx == 0 || batch_count < 0)
49+
if(M < 0 || N < 0 || lda < M || ldc < M || batch_count < 0)
4850
{
4951
status = HIPBLAS_STATUS_INVALID_VALUE;
5052
return status;
@@ -142,12 +144,12 @@ hipblasStatus_t testing_dgmm_batched(Arguments argus)
142144
if(HIPBLAS_SIDE_RIGHT == side)
143145
{
144146
hC_gold[b][i1 + i2 * ldc]
145-
= hA_copy[b][i1 + i2 * lda] + hx_copy[b][i2 * incx];
147+
= hA_copy[b][i1 + i2 * lda] * hx_copy[b][i2 * incx];
146148
}
147149
else
148150
{
149151
hC_gold[b][i1 + i2 * ldc]
150-
= hA_copy[b][i1 + i2 * lda] + hx_copy[b][i1 * incx];
152+
= hA_copy[b][i1 + i2 * lda] * hx_copy[b][i1 * incx];
151153
}
152154
}
153155
}

clients/include/testing_dgmm_strided_batched.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ hipblasStatus_t testing_dgmm_strided_batched(Arguments argus)
4040
int stride_A = size_t(lda) * N * stride_scale;
4141
int stride_x = size_t(incx) * k * stride_scale;
4242
int stride_C = size_t(ldc) * N * stride_scale;
43+
if(!stride_x)
44+
stride_x = 1;
4345

4446
int A_size = stride_A * batch_count;
4547
int C_size = stride_C * batch_count;
@@ -49,7 +51,7 @@ hipblasStatus_t testing_dgmm_strided_batched(Arguments argus)
4951

5052
// argument sanity check, quick return if input parameters are invalid before allocating invalid
5153
// memory
52-
if(M < 0 || N < 0 || lda < M || ldc < M || incx == 0 || batch_count < 0)
54+
if(M < 0 || N < 0 || lda < M || ldc < M || batch_count < 0)
5355
{
5456
status = HIPBLAS_STATUS_INVALID_VALUE;
5557
return status;
@@ -123,11 +125,11 @@ hipblasStatus_t testing_dgmm_strided_batched(Arguments argus)
123125
{
124126
if(HIPBLAS_SIDE_RIGHT == side)
125127
{
126-
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] + hx_copyb[i2 * incx];
128+
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] * hx_copyb[i2 * incx];
127129
}
128130
else
129131
{
130-
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] + hx_copyb[i1 * incx];
132+
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] * hx_copyb[i1 * incx];
131133
}
132134
}
133135
}

clients/include/testing_geqrf.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using namespace std;
2020
template <typename T, typename U>
2121
hipblasStatus_t testing_geqrf(Arguments argus)
2222
{
23-
bool FORTRAN = argus.fortran;
23+
bool FORTRAN = argus.fortran;
2424
auto hipblasGeqrfFn = FORTRAN ? hipblasGeqrf<T, true> : hipblasGeqrf<T, false>;
2525

2626
int M = argus.M;

clients/include/testing_geqrf_batched.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ using namespace std;
2020
template <typename T, typename U>
2121
hipblasStatus_t testing_geqrf_batched(Arguments argus)
2222
{
23-
bool FORTRAN = argus.fortran;
24-
auto hipblasGeqrfBatchedFn = FORTRAN ? hipblasGeqrfBatched<T, true> : hipblasGeqrfBatched<T, false>;
23+
bool FORTRAN = argus.fortran;
24+
auto hipblasGeqrfBatchedFn
25+
= FORTRAN ? hipblasGeqrfBatched<T, true> : hipblasGeqrfBatched<T, false>;
2526

2627
int M = argus.M;
2728
int N = argus.N;

clients/include/testing_geqrf_strided_batched.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ using namespace std;
2020
template <typename T, typename U>
2121
hipblasStatus_t testing_geqrf_strided_batched(Arguments argus)
2222
{
23-
bool FORTRAN = argus.fortran;
24-
auto hipblasGeqrfStridedBatchedFn = FORTRAN ? hipblasGeqrfStridedBatched<T, true> : hipblasGeqrfStridedBatched<T, false>;
23+
bool FORTRAN = argus.fortran;
24+
auto hipblasGeqrfStridedBatchedFn
25+
= FORTRAN ? hipblasGeqrfStridedBatched<T, true> : hipblasGeqrfStridedBatched<T, false>;
2526

2627
int M = argus.M;
2728
int N = argus.N;

clients/include/testing_getrf.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using namespace std;
2020
template <typename T, typename U>
2121
hipblasStatus_t testing_getrf(Arguments argus)
2222
{
23-
bool FORTRAN = argus.fortran;
23+
bool FORTRAN = argus.fortran;
2424
auto hipblasGetrfFn = FORTRAN ? hipblasGetrf<T, true> : hipblasGetrf<T, false>;
2525

2626
int M = argus.N;

clients/include/testing_getrf_batched.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ using namespace std;
2020
template <typename T, typename U>
2121
hipblasStatus_t testing_getrf_batched(Arguments argus)
2222
{
23-
bool FORTRAN = argus.fortran;
24-
auto hipblasGetrfBatchedFn = FORTRAN ? hipblasGetrfBatched<T, true> : hipblasGetrfBatched<T, false>;
23+
bool FORTRAN = argus.fortran;
24+
auto hipblasGetrfBatchedFn
25+
= FORTRAN ? hipblasGetrfBatched<T, true> : hipblasGetrfBatched<T, false>;
2526

2627
int M = argus.N;
2728
int N = argus.N;

0 commit comments

Comments
 (0)