|
76 | 76 |
|
77 | 77 | #endif |
78 | 78 |
|
79 | | -#define NEAR_ASSERT_HALF(a, b, err) ASSERT_NEAR(float(a), float(b), err) |
// NEAR_ASSERT_HALF(a, b, err): converts both operands with half_to_float()
// and hands the resulting values to ASSERT_NEAR with `err` as the tolerance.
| 79 | +#define NEAR_ASSERT_HALF(a, b, err) ASSERT_NEAR(half_to_float(a), half_to_float(b), err) |
// NEAR_ASSERT_BF16(a, b, err): converts both operands with bfloat16_to_float()
// and hands the resulting values to ASSERT_NEAR with `err` as the tolerance.
| 80 | +#define NEAR_ASSERT_BF16(a, b, err) ASSERT_NEAR(bfloat16_to_float(a), bfloat16_to_float(b), err) |
80 | 81 |
|
81 | 82 | #define NEAR_ASSERT_COMPLEX(a, b, err) \ |
82 | 83 | do \ |
@@ -105,6 +106,13 @@ void near_check_general( |
105 | 106 | NEAR_CHECK(M, N, 1, lda, 0, hCPU, hGPU, abs_error, NEAR_ASSERT_HALF); |
106 | 107 | } |
107 | 108 |
|
// near_check_general specialization for hipblasBfloat16 (non-batched form).
// Forwards M, N, lda, the two host buffers and abs_error to NEAR_CHECK, using
// NEAR_ASSERT_BF16 as the per-element assertion macro.
// The literal 1 and 0 occupy the argument slots where the strided-batched
// overload passes batch_count and strideA.
// NOTE(review): NEAR_CHECK is defined outside this view; its iteration order
// and indexing are not documented here -- confirm against its definition.
| 109 | +template <> |
| 110 | +void near_check_general( |
| 111 | + int M, int N, int lda, hipblasBfloat16* hCPU, hipblasBfloat16* hGPU, double abs_error) |
| 112 | +{ |
| 113 | + NEAR_CHECK(M, N, 1, lda, 0, hCPU, hGPU, abs_error, NEAR_ASSERT_BF16); |
| 114 | +} |
| 115 | + |
108 | 116 | template <> |
109 | 117 | void near_check_general( |
110 | 118 | int M, int N, int lda, hipblasComplex* hCPU, hipblasComplex* hGPU, double abs_error) |
@@ -160,6 +168,19 @@ void near_check_general(int M, |
160 | 168 | NEAR_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, abs_error, NEAR_ASSERT_HALF); |
161 | 169 | } |
162 | 170 |
|
// near_check_general specialization for hipblasBfloat16 (strided-batched form).
// Forwards M, N, batch_count, lda, strideA, the two host buffers and abs_error
// to NEAR_CHECK, using NEAR_ASSERT_BF16 as the per-element assertion macro.
// NOTE(review): NEAR_CHECK is defined outside this view; how it applies
// strideA between batches should be confirmed against its definition.
| 171 | +template <> |
| 172 | +void near_check_general(int M, |
| 173 | + int N, |
| 174 | + int batch_count, |
| 175 | + int lda, |
| 176 | + hipblasStride strideA, |
| 177 | + hipblasBfloat16* hCPU, |
| 178 | + hipblasBfloat16* hGPU, |
| 179 | + double abs_error) |
| 180 | +{ |
| 181 | + NEAR_CHECK(M, N, batch_count, lda, strideA, hCPU, hGPU, abs_error, NEAR_ASSERT_BF16); |
| 182 | +} |
| 183 | + |
163 | 184 | template <> |
164 | 185 | void near_check_general(int M, |
165 | 186 | int N, |
@@ -200,6 +221,18 @@ void near_check_general(int M, |
200 | 221 | NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, abs_error, NEAR_ASSERT_HALF); |
201 | 222 | } |
202 | 223 |
|
// near_check_general specialization for arrays of host_vector<hipblasBfloat16>
// (batched form). Forwards M, N, batch_count, lda, both arrays and abs_error to
// NEAR_CHECK_B, using NEAR_ASSERT_BF16 as the per-element assertion macro.
// NOTE(review): NEAR_CHECK_B is defined outside this view; presumably it indexes
// hCPU/hGPU once per batch -- confirm against its definition.
| 224 | +template <> |
| 225 | +void near_check_general(int M, |
| 226 | + int N, |
| 227 | + int batch_count, |
| 228 | + int lda, |
| 229 | + host_vector<hipblasBfloat16> hCPU[], |
| 230 | + host_vector<hipblasBfloat16> hGPU[], |
| 231 | + double abs_error) |
| 232 | +{ |
| 233 | + NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, abs_error, NEAR_ASSERT_BF16); |
| 234 | +} |
| 235 | + |
203 | 236 | template <> |
204 | 237 | void near_check_general(int M, |
205 | 238 | int N, |
@@ -262,6 +295,18 @@ void near_check_general(int M, |
262 | 295 | NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, abs_error, NEAR_ASSERT_HALF); |
263 | 296 | } |
264 | 297 |
|
// near_check_general specialization for arrays of raw hipblasBfloat16 pointers
// (batched form). Forwards M, N, batch_count, lda, both pointer arrays and
// abs_error to NEAR_CHECK_B, using NEAR_ASSERT_BF16 as the per-element
// assertion macro.
// NOTE(review): NEAR_CHECK_B is defined outside this view; presumably it indexes
// hCPU/hGPU once per batch -- confirm against its definition.
| 298 | +template <> |
| 299 | +void near_check_general(int M, |
| 300 | + int N, |
| 301 | + int batch_count, |
| 302 | + int lda, |
| 303 | + hipblasBfloat16* hCPU[], |
| 304 | + hipblasBfloat16* hGPU[], |
| 305 | + double abs_error) |
| 306 | +{ |
| 307 | + NEAR_CHECK_B(M, N, batch_count, lda, hCPU, hGPU, abs_error, NEAR_ASSERT_BF16); |
| 308 | +} |
| 309 | + |
265 | 310 | template <> |
266 | 311 | void near_check_general( |
267 | 312 | int M, int N, int batch_count, int lda, float* hCPU[], float* hGPU[], double abs_error) |
|
0 commit comments