diff --git a/CHANGELOG.md b/CHANGELOG.md index e0dc4481..2701cea4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,41 @@ Full documentation for hipSOLVER is available at the [hipSOLVER Documentation](https://rocm.docs.amd.com/projects/hipSOLVER/en/latest/index.html). +## (Unreleased) hipSOLVER + +### Added + +* Added compatibility-only functions + * csrlsvqr + * hipsolverSpCcsrlsvqr, hipsolverSpZcsrlsvqr + +### Changed +### Removed +### Optimized +### Resolved issues + +* Corrected the value of `lwork` returned by various `bufferSize` functions to be consistent with NVIDIA cuSOLVER. The following functions will + now return `lwork` such that the workspace size (in bytes) is `sizeof(T) * lwork`, rather than `lwork`. To restore the original behavior, set + environment variable `HIPSOLVER_BUFFERSIZE_RETURN_BYTES`. + * hipsolverXorgbr_bufferSize, hipsolverXorgqr_bufferSize, hipsolverXorgtr_bufferSize, hipsolverXormqr_bufferSize, hipsolverXormtr_bufferSize, + hipsolverXgesvd_bufferSize, hipsolverXgesvdj_bufferSize, hipsolverXgesvdBatched_bufferSize, hipsolverXgesvdaStridedBatched_bufferSize, + hipsolverXsyevd_bufferSize, hipsolverXsyevdx_bufferSize, hipsolverXsyevj_bufferSize, hipsolverXsyevjBatched_bufferSize, + hipsolverXsygvd_bufferSize, hipsolverXsygvdx_bufferSize, hipsolverXsygvj_bufferSize, hipsolverXsytrd_bufferSize, hipsolverXsytrf_bufferSize + +### Known issues +### Upcoming changes + + +## hipSOLVER 2.5.0 for ROCm 6.5.0 + +### Upcoming changes + +* With the rocSOLVER backend, the bufferSize methods are currently outputting lwork such that the required workspace +size (in bytes) is lwork. In ROCm 7.0 this will change to make the rocSOLVER backend consistent with cuSOLVER. The +changed bufferSize methods will then return lwork such that the required workspace size (in bytes) is sizeof(T) * lwork, +where T is the used precision. This change will break ABI backward compatibility. + + ## hipSOLVER 2.4.0 for ROCm 6.4.0 ### Added @@ -12,6 +47,11 @@ Full documentation for hipSOLVER is available at the [hipSOLVER Documentation](h ### Upcoming changes * With the rocSOLVER backend, the bufferSize methods are currently outputting `lwork` such that the required workspace size (in bytes) is `lwork`. In ROCm 7.0 this will change to make the rocSOLVER backend consistent with cuSOLVER. The changed bufferSize methods will then return `lwork` so that the required workspace size (in bytes) is `sizeof(T) * lwork`, where T is the precision being used. This change will break ABI backward compatibility. +* With the rocSOLVER backend, the bufferSize methods are currently outputting lwork such that the required workspace + size (in bytes) is lwork. In ROCm 7.0 this will change to make the rocSOLVER backend consistent with cuSOLVER. The + changed bufferSize methods will then return lwork such that the required workspace size (in bytes) is sizeof(T) * lwork, + where T is the used precision. This change will break ABI backward compatibility. + ## hipSOLVER 2.3.0 for ROCm 6.3.0 diff --git a/clients/include/testing_gesvd.hpp b/clients/include/testing_gesvd.hpp index 2041523c..8edffe58 100644 --- a/clients/include/testing_gesvd.hpp +++ b/clients/include/testing_gesvd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -322,7 +322,8 @@ void testing_gesvd_bad_arg() // int size_W; // hipsolver_gesvd_bufferSize(API, handle, left_svect, right_svect, m, n, dA.data(), lda, &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -350,7 +351,10 @@ void testing_gesvd_bad_arg() int size_W; hipsolver_gesvd_bufferSize( API, handle, left_svect, right_svect, m, n, dA.data(), lda, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -1062,10 +1066,12 @@ void testing_gesvd(Arguments& argus) hipsolver_gesvd_bufferSize(API, handle, leftv, rightv, m, n, (T*)nullptr, lda, &w1); hipsolver_gesvd_bufferSize(API, handle, leftvT, rightvT, mT, nT, (T*)nullptr, lda, &w2); size_W = max(w1, w2); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -1089,7 +1095,7 @@ void testing_gesvd(Arguments& argus) device_strided_batch_vector dinfo(1, 1, 1, bc); device_strided_batch_vector dVT(size_VT, 1, stVT, bc); device_strided_batch_vector dUT(size_UT, 1, stUT, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_VT) CHECK_HIP_ERROR(dVT.memcheck()); if(size_UT) diff --git a/clients/include/testing_gesvda.hpp b/clients/include/testing_gesvda.hpp index a149e4a2..cc2bc4a8 100644 --- a/clients/include/testing_gesvda.hpp +++ b/clients/include/testing_gesvda.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -281,7 +281,8 @@ void testing_gesvda_bad_arg() // stV, // &size_W, // bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -344,7 +345,10 @@ void testing_gesvda_bad_arg() stV, &size_W, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -846,10 +850,12 @@ void testing_gesvda(Arguments& argus) stV, &size_W, bc); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -870,7 +876,7 @@ void testing_gesvda(Arguments& argus) device_strided_batch_vector dV(size_V, 1, stV, bc); device_strided_batch_vector dU(size_U, 1, stU, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_S) CHECK_HIP_ERROR(dS.memcheck()); if(size_V) diff --git a/clients/include/testing_gesvdj.hpp b/clients/include/testing_gesvdj.hpp index 06381a60..408b7e16 100644 --- a/clients/include/testing_gesvdj.hpp +++ b/clients/include/testing_gesvdj.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -278,7 +278,8 @@ void testing_gesvdj_bad_arg() // &size_W, // params, // bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -337,7 +338,10 @@ void testing_gesvdj_bad_arg() &size_W, params, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -888,10 +892,12 @@ void testing_gesvdj(Arguments& argus) &size_W, params, bc); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -912,7 +918,7 @@ void testing_gesvdj(Arguments& argus) device_strided_batch_vector dV(size_V, 1, stV, bc); device_strided_batch_vector dU(size_U, 1, stU, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_S) CHECK_HIP_ERROR(dS.memcheck()); if(size_V) diff --git a/clients/include/testing_orgbr_ungbr.hpp b/clients/include/testing_orgbr_ungbr.hpp index f82ccdfd..388c895a 100644 --- a/clients/include/testing_orgbr_ungbr.hpp +++ b/clients/include/testing_orgbr_ungbr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -88,7 +88,9 @@ void testing_orgbr_ungbr_bad_arg() int size_W; hipsolver_orgbr_ungbr_bufferSize( API, handle, side, m, n, k, dA.data(), lda, dIpiv.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -376,10 +378,12 @@ void testing_orgbr_ungbr(Arguments& argus) int size_W; hipsolver_orgbr_ungbr_bufferSize( API, handle, side, m, n, k, (T*)nullptr, lda, (T*)nullptr, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -392,7 +396,7 @@ void testing_orgbr_ungbr(Arguments& argus) device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_orgqr_ungqr.hpp b/clients/include/testing_orgqr_ungqr.hpp index f3ff28ab..8ebdc7bf 100644 --- a/clients/include/testing_orgqr_ungqr.hpp +++ b/clients/include/testing_orgqr_ungqr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -79,7 +79,9 @@ void testing_orgqr_ungqr_bad_arg() int size_W; hipsolver_orgqr_ungqr_bufferSize(API, handle, m, n, k, dA.data(), lda, dIpiv.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -280,10 +282,12 @@ void testing_orgqr_ungqr(Arguments& argus) // memory size query is necessary int size_W; hipsolver_orgqr_ungqr_bufferSize(API, handle, m, n, k, (T*)nullptr, lda, (T*)nullptr, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -296,7 +300,7 @@ void testing_orgqr_ungqr(Arguments& argus) device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_orgtr_ungtr.hpp b/clients/include/testing_orgtr_ungtr.hpp index fed1c860..fe466301 100644 --- a/clients/include/testing_orgtr_ungtr.hpp +++ b/clients/include/testing_orgtr_ungtr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -80,7 +80,9 @@ void testing_orgtr_ungtr_bad_arg() int size_W; hipsolver_orgtr_ungtr_bufferSize(API, handle, uplo, n, dA.data(), lda, dIpiv.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -286,10 +288,12 @@ void testing_orgtr_ungtr(Arguments& argus) // memory size query is necessary int size_W; hipsolver_orgtr_ungtr_bufferSize(API, handle, uplo, n, (T*)nullptr, lda, (T*)nullptr, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -302,7 +306,7 @@ void testing_orgtr_ungtr(Arguments& argus) device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_ormqr_unmqr.hpp b/clients/include/testing_ormqr_unmqr.hpp index 91177c41..47bdbc33 100644 --- a/clients/include/testing_ormqr_unmqr.hpp +++ b/clients/include/testing_ormqr_unmqr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -186,7 +186,9 @@ void testing_ormqr_unmqr_bad_arg() int size_W; hipsolver_ormqr_unmqr_bufferSize( API, handle, side, trans, m, n, k, dA.data(), lda, dIpiv.data(), dC.data(), ldc, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -528,10 +530,12 @@ void testing_ormqr_unmqr(Arguments& argus) (T*)nullptr, ldc, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -546,7 +550,7 @@ void testing_ormqr_unmqr(Arguments& argus) device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_ormtr_unmtr.hpp b/clients/include/testing_ormtr_unmtr.hpp index 2b76c9e6..8fbe3281 100644 --- a/clients/include/testing_ormtr_unmtr.hpp +++ b/clients/include/testing_ormtr_unmtr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -237,7 +237,9 @@ void testing_ormtr_unmtr_bad_arg() dC.data(), ldc, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -582,10 +584,12 @@ void testing_ormtr_unmtr(Arguments& argus) (T*)nullptr, ldc, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -600,7 +604,7 @@ void testing_ormtr_unmtr(Arguments& argus) device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_syevd_heevd.hpp b/clients/include/testing_syevd_heevd.hpp index 091a13b5..32668c84 100644 --- a/clients/include/testing_syevd_heevd.hpp +++ b/clients/include/testing_syevd_heevd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -123,7 +123,8 @@ void testing_syevd_heevd_bad_arg() // int size_W; // hipsolver_syevd_heevd_bufferSize( // API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -155,7 +156,10 @@ void testing_syevd_heevd_bad_arg() int size_W; hipsolver_syevd_heevd_bufferSize( API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -579,10 +583,12 @@ void testing_syevd_heevd(Arguments& argus) int size_W; hipsolver_syevd_heevd_bufferSize( API, handle, evect, uplo, n, (T*)nullptr, lda, (S*)nullptr, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -595,7 +601,7 @@ void testing_syevd_heevd(Arguments& argus) // device device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); CHECK_HIP_ERROR(dinfo.memcheck()); diff --git a/clients/include/testing_syevdx_heevdx.hpp b/clients/include/testing_syevdx_heevdx.hpp index a3fe44cf..3645922c 100644 --- a/clients/include/testing_syevdx_heevdx.hpp +++ b/clients/include/testing_syevdx_heevdx.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -271,7 +271,8 @@ void testing_syevdx_heevdx_bad_arg() // hNev.data(), // dW.data(), // &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -323,7 +324,10 @@ void testing_syevdx_heevdx_bad_arg() hNev.data(), dW.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -844,10 +848,13 @@ void testing_syevdx_heevdx(Arguments& argus) (int*)nullptr, (S*)nullptr, &size_Work); + size_t bytes_Work = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_Work + : sizeof(T) * size_Work; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_Work); + rocsolver_bench_inform(inform_mem_query, bytes_Work); return; } @@ -862,7 +869,8 @@ void testing_syevdx_heevdx(Arguments& argus) // device device_strided_batch_vector dW(size_W, 1, stW, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_Work, 1, size_Work, 1); // size_W accounts for bc + device_strided_batch_vector dWork( + bytes_Work, 1, bytes_Work, 1); // bytes_Work accounts for bc if(size_W) CHECK_HIP_ERROR(dW.memcheck()); CHECK_HIP_ERROR(dinfo.memcheck()); diff --git a/clients/include/testing_syevj_heevj.hpp b/clients/include/testing_syevj_heevj.hpp index b8cfd292..f67c67a1 100644 --- a/clients/include/testing_syevj_heevj.hpp +++ b/clients/include/testing_syevj_heevj.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -181,7 +181,8 @@ void testing_syevj_heevj_bad_arg() // int size_W; // hipsolver_syevj_heevj_bufferSize( // API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W, params, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -214,7 +215,10 @@ void testing_syevj_heevj_bad_arg() int size_W; hipsolver_syevj_heevj_bufferSize( API, STRIDED, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W, params, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -694,10 +698,12 @@ void testing_syevj_heevj(Arguments& argus) int size_W; hipsolver_syevj_heevj_bufferSize( API, STRIDED, handle, evect, uplo, n, (T*)nullptr, lda, (S*)nullptr, &size_W, params, bc); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -712,7 +718,7 @@ void testing_syevj_heevj(Arguments& argus) // device device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); CHECK_HIP_ERROR(dinfo.memcheck()); diff --git a/clients/include/testing_sygvd_hegvd.hpp b/clients/include/testing_sygvd_hegvd.hpp index 4fcec2df..785fcd34 100644 --- a/clients/include/testing_sygvd_hegvd.hpp +++ b/clients/include/testing_sygvd_hegvd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -248,7 +248,8 @@ void testing_sygvd_hegvd_bad_arg() // ldb, // dD.data(), // &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -286,7 +287,10 @@ void testing_sygvd_hegvd_bad_arg() int size_W; hipsolver_sygvd_hegvd_bufferSize( API, handle, itype, evect, uplo, n, dA.data(), lda, dB.data(), ldb, dD.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -858,10 +862,12 @@ void testing_sygvd_hegvd(Arguments& argus) ldb, (S*)nullptr, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -879,7 +885,7 @@ void testing_sygvd_hegvd(Arguments& argus) // device_batch_vector dB(size_B, 1, bc); // device_strided_batch_vector dD(size_D, 1, stD, bc); // device_strided_batch_vector dInfo(1, 1, 1, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc // if(size_A) // CHECK_HIP_ERROR(dA.memcheck()); // if(size_B) @@ -963,7 +969,7 @@ void testing_sygvd_hegvd(Arguments& argus) device_strided_batch_vector dB(size_B, 1, stB, bc); device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_B) diff --git a/clients/include/testing_sygvdx_hegvdx.hpp b/clients/include/testing_sygvdx_hegvdx.hpp index 4e7301c3..0a034f90 100644 --- a/clients/include/testing_sygvdx_hegvdx.hpp +++ b/clients/include/testing_sygvdx_hegvdx.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -372,7 +372,8 @@ void testing_sygvdx_hegvdx_bad_arg() // hNev.data(), // dW.data(), // &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -437,7 +438,10 @@ void testing_sygvdx_hegvdx_bad_arg() hNev.data(), dW.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -1170,10 +1174,13 @@ void testing_sygvdx_hegvdx(Arguments& argus) (int*)nullptr, (S*)nullptr, &size_Work); + size_t bytes_Work = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_Work + : sizeof(T) * size_Work; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_Work); + rocsolver_bench_inform(inform_mem_query, bytes_Work); return; } @@ -1188,7 +1195,8 @@ void testing_sygvdx_hegvdx(Arguments& argus) // device device_strided_batch_vector dW(size_W, 1, stW, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_Work, 1, size_Work, 1); // size_W accounts for bc + device_strided_batch_vector dWork( + bytes_Work, 1, bytes_Work, 1); // bytes_Work accounts for bc if(size_W) CHECK_HIP_ERROR(dW.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); diff --git a/clients/include/testing_sygvj_hegvj.hpp b/clients/include/testing_sygvj_hegvj.hpp index cbfd60d3..e25af611 100644 --- a/clients/include/testing_sygvj_hegvj.hpp +++ b/clients/include/testing_sygvj_hegvj.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -259,7 +259,8 @@ void testing_sygvj_hegvj_bad_arg() // dD.data(), // &size_W, // params); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -309,7 +310,10 @@ void testing_sygvj_hegvj_bad_arg() dD.data(), &size_W, params); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -919,10 +923,12 @@ void testing_sygvj_hegvj(Arguments& argus) (S*)nullptr, &size_W, params); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -937,7 +943,7 @@ void testing_sygvj_hegvj(Arguments& argus) // device device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); diff --git a/clients/include/testing_sytrd_hetrd.hpp b/clients/include/testing_sytrd_hetrd.hpp index fbd46e7b..98885c5f 100644 --- a/clients/include/testing_sytrd_hetrd.hpp +++ b/clients/include/testing_sytrd_hetrd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -211,7 +211,8 @@ void testing_sytrd_hetrd_bad_arg() // int size_W; // hipsolver_sytrd_hetrd_bufferSize( // API, handle, uplo, n, dA.data(), lda, dD.data(), dE.data(), dTau.data(), &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -250,7 +251,10 @@ void testing_sytrd_hetrd_bad_arg() int size_W; hipsolver_sytrd_hetrd_bufferSize( API, handle, uplo, n, dA.data(), lda, dD.data(), dE.data(), dTau.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -768,10 +772,12 @@ void testing_sytrd_hetrd(Arguments& argus) int size_W; hipsolver_sytrd_hetrd_bufferSize( API, handle, uplo, n, (T*)nullptr, lda, (S*)nullptr, (S*)nullptr, (T*)nullptr, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -787,7 +793,7 @@ void testing_sytrd_hetrd(Arguments& argus) device_strided_batch_vector dE(size_E, 1, stE, bc); device_strided_batch_vector dTau(size_tau, 1, stP, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); if(size_E) diff --git a/clients/include/testing_sytrf.hpp b/clients/include/testing_sytrf.hpp index 6705a238..2901b0fa 100644 --- a/clients/include/testing_sytrf.hpp +++ b/clients/include/testing_sytrf.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -101,7 +101,8 @@ void testing_sytrf_bad_arg() // int size_W; // hipsolver_sytrf_bufferSize(API, handle, n, dA.data(), lda, &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -131,7 +132,10 @@ void testing_sytrf_bad_arg() int size_W; hipsolver_sytrf_bufferSize(API, handle, n, dA.data(), lda, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -509,10 +513,12 @@ void testing_sytrf(Arguments& argus) // memory size query is necessary int size_W; hipsolver_sytrf_bufferSize(API, handle, n, (T*)nullptr, lda, &size_W); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -528,7 +534,7 @@ void testing_sytrf(Arguments& argus) // device_batch_vector dA(size_A, 1, bc); // device_strided_batch_vector dIpiv(size_P, 1, stP, bc); // device_strided_batch_vector dInfo(1, 1, 1, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc // if(size_A) // CHECK_HIP_ERROR(dA.memcheck()); // CHECK_HIP_ERROR(dInfo.memcheck()); @@ -594,7 +600,7 @@ void testing_sytrf(Arguments& argus) device_strided_batch_vector dA(size_A, 1, stA, bc); device_strided_batch_vector dIpiv(size_P, 1, stP, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_A) CHECK_HIP_ERROR(dA.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt index 522201da..b7ffbf7d 100644 --- a/library/CMakeLists.txt +++ b/library/CMakeLists.txt @@ -1,5 +1,5 @@ # ######################################################################## -# Copyright (C) 2016-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2016-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/library/src/amd_detail/hipsolver.cpp b/library/src/amd_detail/hipsolver.cpp index dae75071..445695fd 100644 --- a/library/src/amd_detail/hipsolver.cpp +++ b/library/src/amd_detail/hipsolver.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1098,6 +1098,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1142,6 +1145,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1186,6 +1192,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1230,6 +1239,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1257,12 +1269,20 @@ hipsolverStatus_t hipsolverSorgbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1289,12 +1309,20 @@ hipsolverStatus_t hipsolverDorgbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1321,12 +1349,20 @@ hipsolverStatus_t hipsolverCungbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCungbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1359,12 +1395,20 @@ hipsolverStatus_t hipsolverZungbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZungbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1401,6 +1445,9 @@ try rocsolver_sorgqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1431,6 +1478,9 @@ try rocsolver_dorgqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1467,6 +1517,9 @@ try rocsolver_cungqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1503,6 +1556,9 @@ try rocsolver_zungqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1529,12 +1585,20 @@ hipsolverStatus_t hipsolverSorgqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1560,12 +1624,20 @@ hipsolverStatus_t hipsolverDorgqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1591,12 +1663,20 @@ hipsolverStatus_t hipsolverCungqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCungqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1627,12 +1707,20 @@ hipsolverStatus_t hipsolverZungqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZungqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1673,6 +1761,9 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1708,6 +1799,9 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1743,6 +1837,9 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1778,6 +1875,9 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1803,12 +1903,20 @@ hipsolverStatus_t hipsolverSorgtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1833,12 +1941,20 @@ hipsolverStatus_t hipsolverDorgtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1863,12 +1979,20 @@ hipsolverStatus_t hipsolverCungtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCungtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1897,12 +2021,20 @@ hipsolverStatus_t hipsolverZungtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZungtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1957,6 +2089,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2007,6 +2142,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2057,6 +2195,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2107,6 +2248,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2137,12 +2281,20 @@ hipsolverStatus_t hipsolverSormqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSormqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2181,12 +2333,20 @@ hipsolverStatus_t hipsolverDormqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDormqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2225,12 +2385,20 @@ hipsolverStatus_t hipsolverCunmqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCunmqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2269,12 +2437,20 @@ hipsolverStatus_t hipsolverZunmqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZunmqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2334,6 +2510,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2384,6 +2563,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2434,6 +2616,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2484,6 +2669,9 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2514,12 +2702,20 @@ hipsolverStatus_t hipsolverSormtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSormtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2558,12 +2754,20 @@ hipsolverStatus_t hipsolverDormtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDormtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2602,12 +2806,20 @@ hipsolverStatus_t hipsolverCunmtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCunmtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2646,12 +2858,20 @@ hipsolverStatus_t hipsolverZunmtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZunmtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -4019,6 +4239,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4071,6 +4294,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4123,6 +4349,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4175,6 +4404,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4216,13 +4448,19 @@ try work = rwork + std::min(m, n); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4282,13 +4520,19 @@ try work = rwork + std::min(m, n); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4348,13 +4592,19 @@ try work = (hipFloatComplex*)(rwork + std::min(m, n)); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4414,13 +4664,19 @@ try work = (hipDoubleComplex*)(rwork + std::min(m, n)); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4506,6 +4762,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4572,6 +4831,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4638,6 +4900,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4704,6 +4969,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4742,12 +5010,20 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -4805,12 +5081,20 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -4868,12 +5152,20 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -4931,12 +5223,20 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5028,6 +5328,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5099,6 +5402,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5170,6 +5476,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5241,6 +5550,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5279,7 +5591,12 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5296,7 +5613,10 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5359,7 +5679,12 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5376,7 +5701,10 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5439,7 +5767,12 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5456,7 +5789,10 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5519,7 +5855,12 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5536,7 +5877,10 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5644,6 +5988,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5726,6 +6073,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5808,6 +6158,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5889,6 +6242,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5943,7 +6299,10 @@ try if(std::min(m, n) * batch_count > 0) work = (float*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -5965,7 +6324,10 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -6051,7 +6413,10 @@ try if(std::min(m, n) * batch_count > 0) work = (double*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -6073,7 +6438,10 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -6159,7 +6527,10 @@ try if(std::min(m, n) * batch_count > 0) work = (hipFloatComplex*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -6181,7 +6552,10 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -6267,7 +6641,10 @@ try if(std::min(m, n) * batch_count > 0) work = (hipDoubleComplex*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -6289,7 +6666,10 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -8421,6 +8801,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8473,6 +8856,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8525,6 +8911,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8577,6 +8966,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8611,13 +9003,19 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSsyevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -8661,13 +9059,19 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDsyevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -8711,13 +9115,19 @@ try if(n > 0) work = (hipFloatComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCheevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -8761,13 +9171,19 @@ try if(n > 0) work = (hipDoubleComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZheevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -8834,6 +9250,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8890,6 +9309,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8946,6 +9368,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9002,6 +9427,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9034,12 +9462,20 @@ hipsolverStatus_t hipsolverSsyevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9083,12 +9519,20 @@ hipsolverStatus_t hipsolverDsyevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9132,12 +9576,20 @@ hipsolverStatus_t hipsolverCheevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9181,12 +9633,20 @@ hipsolverStatus_t hipsolverZheevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9251,6 +9711,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9303,6 +9766,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9355,6 +9821,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9407,6 +9876,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9439,12 +9911,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9490,12 +9970,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9541,12 +10029,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9592,12 +10088,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9668,6 +10172,9 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9724,6 +10231,9 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9780,6 +10290,9 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9836,6 +10349,9 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9869,12 +10385,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9925,12 +10449,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9981,12 +10513,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -10037,12 +10577,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -10119,6 +10667,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10177,6 +10728,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10235,6 +10789,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10293,6 +10850,9 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10330,13 +10890,19 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -10386,13 +10952,19 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -10442,13 +11014,19 @@ try if(n > 0) work = (hipFloatComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -10498,13 +11076,19 @@ try if(n > 0) work = (hipDoubleComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -10580,6 +11164,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10642,6 +11229,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10704,6 +11294,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10766,6 +11359,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10801,7 +11397,12 @@ hipsolverStatus_t hipsolverSsygvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvdx_bufferSize((rocblas_handle)handle, @@ -10821,7 +11422,10 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -10871,7 +11475,12 @@ hipsolverStatus_t hipsolverDsygvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvdx_bufferSize((rocblas_handle)handle, @@ -10891,7 +11500,10 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -10941,7 +11553,12 @@ hipsolverStatus_t hipsolverChegvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvdx_bufferSize((rocblas_handle)handle, @@ -10961,7 +11578,10 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -11011,7 +11631,12 @@ hipsolverStatus_t hipsolverZhegvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvdx_bufferSize((rocblas_handle)handle, @@ -11031,7 +11656,10 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -11104,6 +11732,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11161,6 +11792,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11218,6 +11852,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11275,6 +11912,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11310,13 +11950,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11367,13 +12014,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11424,13 +12078,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11481,13 +12142,20 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11548,6 +12216,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11592,6 +12263,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11636,6 +12310,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11680,6 +12357,9 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11707,12 +12387,20 @@ hipsolverStatus_t hipsolverSsytrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSsytrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11739,12 +12427,20 @@ hipsolverStatus_t hipsolverDsytrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDsytrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11771,12 +12467,20 @@ hipsolverStatus_t hipsolverChetrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverChetrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11809,12 +12513,20 @@ hipsolverStatus_t hipsolverZhetrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZhetrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11851,6 +12563,9 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11881,6 +12596,9 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11911,6 +12629,9 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11941,6 +12662,9 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11966,12 +12690,20 @@ hipsolverStatus_t hipsolverSsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_ssytrf( @@ -11994,12 +12726,20 @@ hipsolverStatus_t hipsolverDsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_dsytrf( @@ -12022,12 +12762,20 @@ hipsolverStatus_t hipsolverCsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_csytrf((rocblas_handle)handle, @@ -12055,12 +12803,20 @@ hipsolverStatus_t hipsolverZsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_zsytrf((rocblas_handle)handle,