11/* ************************************************************************
2- * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved.
2+ * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved.
33 *
44 * Permission is hereby granted, free of charge, to any person obtaining a copy
55 * of this software and associated documentation files (the "Software"), to deal
@@ -1327,14 +1327,20 @@ rocblas_status rocblas_syr2k_her2k_dispatch(rocblas_handle handle,
13271327 return rocblas_status_success;
13281328}
13291329
1330- template <bool copy_from_C_to_W_C, bool is_upper, typename T, typename TPtr, int DIM_X, int DIM_Y>
1330+ template <bool copy_from_C_to_W_C,
1331+ bool is_upper,
1332+ bool HERM,
1333+ typename T,
1334+ typename TPtr,
1335+ int DIM_X,
1336+ int DIM_Y>
13311337ROCBLAS_KERNEL (DIM_X* DIM_Y)
1332- rocblas_copy_triangular_excluding_diagonal_kernel (rocblas_int n,
1333- TPtr d_C,
1334- rocblas_int ldc,
1335- rocblas_stride stride_C,
1336- T* W_C,
1337- rocblas_int batch_count)
1338+ rocblas_copy_triangular_syrk_herk_kernel (rocblas_int n,
1339+ TPtr d_C,
1340+ rocblas_int ldc,
1341+ rocblas_stride stride_C,
1342+ T* W_C,
1343+ rocblas_int batch_count)
13381344{
13391345 uint32_t batch = blockIdx.z ;
13401346
@@ -1379,19 +1385,24 @@ rocblas_copy_triangular_excluding_diagonal_kernel(rocblas_int n,
13791385 }
13801386 }
13811387
1388+ // When copying back to C in the Hermitian case, zero out the imaginary part of each diagonal entry
1389+ if constexpr (HERM && !copy_from_C_to_W_C)
1390+ if (row == col && row < n)
1391+ C[row + row * int64_t (ldc)] = std::real (C[row + row * int64_t (ldc)]);
1392+
13821393#if DEVICE_GRID_YZ_16BIT
13831394 }
13841395#endif
13851396}
13861397
1387- template <bool copy_from_C_to_W_C, bool is_upper, typename T, typename TPtr>
1388- rocblas_status rocblas_copy_triangular_excluding_diagonal (rocblas_handle handle,
1389- rocblas_int n,
1390- TPtr C,
1391- rocblas_int ldc,
1392- rocblas_stride stride_C,
1393- T* W_C,
1394- rocblas_int batch_count)
1398+ template <bool copy_from_C_to_W_C, bool is_upper, bool HERM, typename T, typename TPtr>
1399+ rocblas_status rocblas_copy_triangular_syrk_herk (rocblas_handle handle,
1400+ rocblas_int n,
1401+ TPtr C,
1402+ rocblas_int ldc,
1403+ rocblas_stride stride_C,
1404+ T* W_C,
1405+ rocblas_int batch_count)
13951406{
13961407 hipStream_t rocblas_stream = handle->get_stream ();
13971408
@@ -1405,12 +1416,13 @@ rocblas_status rocblas_copy_triangular_excluding_diagonal(rocblas_handle handle,
14051416 dim3 gridDim ((n - 1 ) / blockDim.x + 1 , (n - 1 ) / blockDim.y + 1 , batches);
14061417
14071418 // Launch kernel
1408- ROCBLAS_LAUNCH_KERNEL ((rocblas_copy_triangular_excluding_diagonal_kernel<copy_from_C_to_W_C,
1409- is_upper,
1410- T,
1411- TPtr,
1412- DIM_X,
1413- DIM_Y>),
1419+ ROCBLAS_LAUNCH_KERNEL ((rocblas_copy_triangular_syrk_herk_kernel<copy_from_C_to_W_C,
1420+ is_upper,
1421+ HERM,
1422+ T,
1423+ TPtr,
1424+ DIM_X,
1425+ DIM_Y>),
14141426 gridDim,
14151427 blockDim,
14161428 0 ,
0 commit comments