-
Notifications
You must be signed in to change notification settings - Fork 28
Closed
Description
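The fully templated GEMM code (below) reads the input value of `C(i, j)` even when `beta == 0.0`. As a result, if `C` isn't initialized and contains a NaN, that NaN can still propagate through to the output when `beta == 0.0`, because under IEEE 754 arithmetic `0.0 * NaN` is NaN. Reference BLAS specifies that when `beta` is zero, `C` need not be set on input, so its prior contents should not be read.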
Lines 166 to 265 in e954a9b
// alpha != zero | |
if (transA == Op::NoTrans) { | |
if (transB == Op::NoTrans) { | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) | |
C(i, j) *= beta; | |
for (int64_t l = 0; l < k; ++l) { | |
scalar_t alpha_Blj = alpha*B(l, j); | |
for (int64_t i = 0; i < m; ++i) | |
C(i, j) += A(i, l)*alpha_Blj; | |
} | |
} | |
} | |
else if (transB == Op::Trans) { | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) | |
C(i, j) *= beta; | |
for (int64_t l = 0; l < k; ++l) { | |
scalar_t alpha_Bjl = alpha*B(j, l); | |
for (int64_t i = 0; i < m; ++i) | |
C(i, j) += A(i, l)*alpha_Bjl; | |
} | |
} | |
} | |
else { // transB == Op::ConjTrans | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) | |
C(i, j) *= beta; | |
for (int64_t l = 0; l < k; ++l) { | |
scalar_t alpha_Bjl = alpha*conj(B(j, l)); | |
for (int64_t i = 0; i < m; ++i) | |
C(i, j) += A(i, l)*alpha_Bjl; | |
} | |
} | |
} | |
} | |
else if (transA == Op::Trans) { | |
if (transB == Op::NoTrans) { | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) { | |
scalar_t sum = zero; | |
for (int64_t l = 0; l < k; ++l) | |
sum += A(l, i)*B(l, j); | |
C(i, j) = alpha*sum + beta*C(i, j); | |
} | |
} | |
} | |
else if (transB == Op::Trans) { | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) { | |
scalar_t sum = zero; | |
for (int64_t l = 0; l < k; ++l) | |
sum += A(l, i)*B(j, l); | |
C(i, j) = alpha*sum + beta*C(i, j); | |
} | |
} | |
} | |
else { // transB == Op::ConjTrans | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) { | |
scalar_t sum = zero; | |
for (int64_t l = 0; l < k; ++l) | |
sum += A(l, i)*conj(B(j, l)); | |
C(i, j) = alpha*sum + beta*C(i, j); | |
} | |
} | |
} | |
} | |
else { // transA == Op::ConjTrans | |
if (transB == Op::NoTrans) { | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) { | |
scalar_t sum = zero; | |
for (int64_t l = 0; l < k; ++l) | |
sum += conj(A(l, i))*B(l, j); | |
C(i, j) = alpha*sum + beta*C(i, j); | |
} | |
} | |
} | |
else if (transB == Op::Trans) { | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) { | |
scalar_t sum = zero; | |
for (int64_t l = 0; l < k; ++l) | |
sum += conj(A(l, i))*B(j, l); | |
C(i, j) = alpha*sum + beta*C(i, j); | |
} | |
} | |
} | |
else { // transB == Op::ConjTrans | |
for (int64_t j = 0; j < n; ++j) { | |
for (int64_t i = 0; i < m; ++i) { | |
scalar_t sum = zero; | |
for (int64_t l = 0; l < k; ++l) | |
sum += A(l, i)*B(j, l); // little improvement here | |
C(i, j) = alpha*conj(sum) + beta*C(i, j); | |
} | |
} | |
} | |
} |
@weslleyspereira, I recently suggested that tests for this situation be added to your BLAS stress tests. Once you add such tests, it would be good to run them against BLAS++, since other fully templated functions in BLAS++ may make similar mistakes.
weslleyspereira
Metadata
Metadata
Assignees
Labels
No labels