@@ -135,9 +135,9 @@ void selectColsIf(const raft::handle_t& handle,
135135 raft::linalg::map (
136136 handle,
137137 raft::make_const_mdspan (mask),
138+ raft::make_const_mdspan (rangeVec.view ()),
138139 rangeVec.view (),
139- [] __device__ (index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1 ; },
140- rangeVec.view ());
140+ [] __device__ (index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1 ; });
141141 thrust::sort (rmm::exec_policy (stream),
142142 rangeVec.data_handle (),
143143 rangeVec.data_handle () + rangeVec.size (),
@@ -172,11 +172,11 @@ void truncEig(
172172 }
173173 if (eigVectorTrunc.has_value () && ncols > eigVectorTrunc->extent (1 ))
174174 raft::matrix::truncZeroOrigin (eigVectorin.data_handle (),
175- n_rows ,
175+ nrows ,
176176 eigVectorTrunc->data_handle (),
177177 nrows,
178178 eigVectorTrunc->extent (1 ),
179- stream );
179+ handle. get_stream () );
180180}
181181
182182// C = A * B
@@ -447,7 +447,7 @@ bool eigh(const raft::handle_t& handle,
447447
448448 raft::linalg::eig_dc (handle, raft::make_const_mdspan (F.view ()), Fvecs.view (), eigVals);
449449 raft::linalg::gemm (handle, Ri.view (), Fvecs.view (), eigVecs);
450- return cho_success
450+ return cho_success;
451451}
452452
453453/* *
@@ -604,8 +604,10 @@ void lobpcg(
604604 auto eigVectorBuffer = rmm::device_uvector<value_t >(size_x * size_x, stream); // rmm because of resize
605605 auto eigVectorView = raft::make_device_matrix_view<value_t , index_t , raft::col_major>(eigVectorBuffer.data (), size_x, size_x);
606606 auto eigLambda = raft::make_device_vector<value_t , index_t >(handle, size_x);
607- eigh (handle, gramXAX.view (), eigVectorView, eigLambda.view ());
608- truncEig (handle, eigVectorView, eigLambda.view (), size_x, largest);
607+ std::optional<raft::device_matrix_view<value_t , index_t , raft::col_major>> empty_matrix_opt = std::nullopt ;
608+ eigh (handle, gramXAX.view (), empty_matrix_opt, eigVectorView, eigLambda.view ());
609+
610+ truncEig (handle, eigVectorView, empty_matrix_opt, eigLambda.view (), largest);
609611 // Slice not needed for first eigh
610612 // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0,
611613 // eigVectorFull.extent(0), size_x));
@@ -623,6 +625,9 @@ void lobpcg(
623625 auto identView = raft::make_device_matrix_view<value_t , index_t , raft::col_major>(
624626 ident.data (), size_x, size_x);
625627 raft::matrix::eye (handle, identView);
628+ auto identSizeX = raft::make_device_matrix<value_t , index_t , raft::col_major>(
629+ handle, size_x, size_x);
630+ raft::matrix::eye (handle, identSizeX.view ());
626631
627632 auto Pbuffer = rmm::device_uvector<value_t >(0 , stream);
628633 auto APbuffer = rmm::device_uvector<value_t >(0 , stream);
@@ -646,6 +651,8 @@ void lobpcg(
646651
647652 auto aux = raft::make_device_matrix<value_t , index_t , raft::col_major>(
648653 handle, n, size_x);
654+ // auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
655+ auto residual_norms = raft::make_device_vector<value_t , index_t >(handle, size_x);
649656 std::int32_t iteration_number = -1 ;
650657 bool restart = true ;
651658 bool explicitGramFlag = false ;
@@ -664,9 +671,8 @@ void lobpcg(
664671 raft::linalg::subtract (
665672 handle, raft::make_const_mdspan (AX.view ()), raft::make_const_mdspan (aux.view ()), R.view ());
666673
667- auto aux_sum = raft::make_device_vector<value_t , index_t >(handle, size_x);
668674 raft::linalg::reduce (
669- aux_sum .data_handle (),
675+ residual_norms .data_handle (),
670676 R.data_handle (),
671677 size_x,
672678 n,
@@ -677,8 +683,7 @@ void lobpcg(
677683 false ,
678684 raft::sq_op ());
679685
680- auto residual_norms = raft::make_device_vector<value_t , index_t >(handle, size_x);
681- raft::linalg::sqrt (handle, raft::make_const_mdspan (aux_sum.view ()), residual_norms.view ());
686+ // TODO check sqop of reduce raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());
682687
683688 // cupy where & active_mask
684689 raft::linalg::unary_op (handle,
@@ -720,7 +725,7 @@ void lobpcg(
720725 selectColsIf (handle, APView, active_mask.view (), activeAPView);
721726 if (B_opt.has_value ()) {
722727 activeBPView = raft::make_device_matrix_view<value_t , index_t , col_major>(activeBPbuffer.data (), n, currentBlockSize);
723- selectColsIf (handle, BPbuffer. view () , active_mask.view (), activeBPView);
728+ selectColsIf (handle, BPView , active_mask.view (), activeBPView);
724729 }
725730 }
726731 if (M_opt.has_value ()) {
@@ -823,7 +828,7 @@ void lobpcg(
823828
824829 if (!B_opt.has_value ()) {
825830 // Shared memory assignments to simplify the code
826- BXView = X. view () ;
831+ BXView = X;
827832 activeBRView = activeR.view ();
828833 if (!restart)
829834 activeBPView = activePView;
@@ -906,9 +911,9 @@ void lobpcg(
906911 auto gramB = raft::make_device_matrix<value_t , index_t , col_major>(handle, gramDim, gramDim);
907912 auto gramAView = gramA.view ();
908913 auto gramBView = gramB.view ();
909- auto eigLambdaTemp = raft::make_device_vector_view <value_t , index_t >(handle, gramDim);
914+ auto eigLambdaTemp = raft::make_device_vector <value_t , index_t >(handle, gramDim);
910915 auto eigVectorTemp =
911- raft::make_device_matrix_view <value_t , index_t , raft::col_major>(handle, gramDim, gramDim);
916+ raft::make_device_matrix <value_t , index_t , raft::col_major>(handle, gramDim, gramDim);
912917 auto eigLambdaTempView = eigLambdaTemp.view ();
913918 auto eigVectorTempView = eigVectorTemp.view ();
914919 eigVectorBuffer.resize (gramDim * size_x, stream);
@@ -927,19 +932,19 @@ void lobpcg(
927932 handle, currentBlockSize, currentBlockSize);
928933 // create transpose mat
929934 auto gramXAPT = raft::make_device_matrix<value_t , index_t , col_major>(
930- handle, gramXAPT .extent (1 ), gramXAPT .extent (0 ));
935+ handle, gramXAP .extent (1 ), gramXAP .extent (0 ));
931936 auto gramXART = raft::make_device_matrix<value_t , index_t , col_major>(
932- handle, gramXART .extent (1 ), gramXART .extent (0 ));
937+ handle, gramXAR .extent (1 ), gramXAR .extent (0 ));
933938 auto gramRAPT = raft::make_device_matrix<value_t , index_t , col_major>(
934- handle, gramRAPT .extent (1 ), gramRAPT .extent (0 ));
939+ handle, gramRAP .extent (1 ), gramRAP .extent (0 ));
935940 auto gramXBPT = raft::make_device_matrix<value_t , index_t , col_major>(
936- handle, gramXBPT .extent (1 ), gramXBPT .extent (0 ));
941+ handle, gramXBP .extent (1 ), gramXBP .extent (0 ));
937942 auto gramXBRT = raft::make_device_matrix<value_t , index_t , col_major>(
938- handle, gramXBRT .extent (1 ), gramXBRT .extent (0 ));
943+ handle, gramXBR .extent (1 ), gramXBR .extent (0 ));
939944 auto gramRBPT = raft::make_device_matrix<value_t , index_t , col_major>(
940- handle, gramRBPT .extent (1 ), gramRBPT .extent (0 ));
945+ handle, gramRBP .extent (1 ), gramRBP .extent (0 ));
941946 raft::linalg::transpose (handle, gramXAR.view (), gramXART.view ());
942- raft::linalg::transpose (handle, gramXVR .view (), gramXBRT.view ());
947+ raft::linalg::transpose (handle, gramXBR .view (), gramXBRT.view ());
943948
944949 if (!restart) {
945950 raft::linalg::gemm (handle,
@@ -1005,19 +1010,19 @@ void lobpcg(
10051010 gramBView =
10061011 raft::make_device_matrix_view<value_t , index_t , col_major>(gramB.data_handle (), n, n);
10071012
1008- bmat (handle, gramAView, A_blocks);
1009- bmat (handle, gramBView, B_blocks);
1013+ bmat (handle, gramAView, A_blocks, 3 );
1014+ bmat (handle, gramBView, B_blocks, 3 );
10101015
10111016 bool eig_sucess =
1012- eigh (handle, gramA , std::make_optional (gramBView), eigVectorTempView, eigLambdaTempView);
1017+ eigh (handle, gramAView , std::make_optional (gramBView), eigVectorTempView, eigLambdaTempView);
10131018 if (!eig_sucess) restart = true ;
10141019 }
10151020 if (restart) {
10161021 gramDim = gramXAX.extent (1 ) + gramXAR.extent (1 );
10171022 std::vector<raft::device_matrix_view<value_t , index_t , col_major>> A_blocks = {
1018- gramXAX, gramXAR, gramXART, gramRAR};
1023+ gramXAX. view () , gramXAR. view () , gramXART. view () , gramRAR. view () };
10191024 std::vector<raft::device_matrix_view<value_t , index_t , col_major>> B_blocks = {
1020- gramXBX, gramXBR, gramXBRT, gramRBR};
1025+ gramXBX. view () , gramXBR. view () , gramXBRT. view () , gramRBR. view () };
10211026 gramAView = raft::make_device_matrix_view<value_t , index_t , col_major>(
10221027 gramA.data_handle (), gramDim, gramDim);
10231028 gramBView = raft::make_device_matrix_view<value_t , index_t , col_major>(
@@ -1026,8 +1031,8 @@ void lobpcg(
10261031 raft::make_device_vector_view<value_t , index_t >(eigLambdaTempView.data_handle (), gramDim);
10271032 eigVectorTempView = raft::make_device_matrix_view<value_t , index_t , col_major>(
10281033 eigVectorTempView.data_handle (), gramDim, gramDim);
1029- bmat (handle, gramAView, A_blocks);
1030- bmat (handle, gramBView, B_blocks);
1034+ bmat (handle, gramAView, A_blocks, 2 );
1035+ bmat (handle, gramBView, B_blocks, 2 );
10311036 bool eig_sucess = eigh (
10321037 handle, gramAView, std::make_optional (gramBView), eigVectorTempView, eigLambdaTempView);
10331038 ASSERT (eig_sucess, " lobpcg: eigh has failed in lobpcg iterations" );
@@ -1048,20 +1053,20 @@ void lobpcg(
10481053 auto app = raft::make_device_matrix<value_t , index_t , raft::col_major>(handle, n, size_x);
10491054 if (B_opt.has_value ()) {
10501055 auto bpp = raft::make_device_matrix<value_t , index_t , raft::col_major>(handle, n, size_x);
1051- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorX.view (),
1056+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorX.view (),
10521057 raft::matrix::slice_coordinates<index_t >(0 , 0 , size_x, size_x));
10531058 if (!restart) {
1054- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1059+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
10551060 raft::matrix::slice_coordinates<index_t >(size_x, 0 , size_x + currentBlockSize, size_x));
1056- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorP.view (),
1061+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorP.view (),
10571062 raft::matrix::slice_coordinates<index_t >(size_x + currentBlockSize, 0 , gramDim, size_x));
10581063 } else {
1059- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1064+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
10601065 raft::matrix::slice_coordinates<index_t >(size_x, 0 , gramDim, size_x));
10611066 }
10621067
1063- raft::linalg::gemm (handle, activeRView , eigBlockVectorR.view (), pp.view ());
1064- raft::linalg::gemm (handle, activeARView , eigBlockVectorR.view (), app.view ());
1068+ raft::linalg::gemm (handle, activeR. view () , eigBlockVectorR.view (), pp.view ());
1069+ raft::linalg::gemm (handle, activeAR. view () , eigBlockVectorR.view (), app.view ());
10651070 raft::linalg::gemm (handle, activeBRView, eigBlockVectorR.view (), bpp.view ());
10661071 if (!restart) {
10671072 raft::linalg::gemm (handle, activePView, eigBlockVectorP.view (), pp.view (), one, one);
@@ -1087,20 +1092,20 @@ void lobpcg(
10871092 raft::copy (AX.data_handle (), app.data_handle (), app.size (), stream);
10881093 raft::copy (BXView.data_handle (), bpp.data_handle (), bpp.size (), stream);
10891094 } else {
1090- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorX.view (),
1095+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorX.view (),
10911096 raft::matrix::slice_coordinates<index_t >(0 , 0 , size_x, size_x));
10921097 if (!restart) {
1093- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1098+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
10941099 raft::matrix::slice_coordinates<index_t >(size_x, 0 , size_x + currentBlockSize, size_x));
1095- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorP.view (),
1100+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorP.view (),
10961101 raft::matrix::slice_coordinates<index_t >(size_x + currentBlockSize, 0 , gramDim, size_x));
10971102 } else {
1098- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1103+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
10991104 raft::matrix::slice_coordinates<index_t >(size_x, 0 , gramDim, size_x));
11001105 }
11011106
1102- raft::linalg::gemm (handle, activeRView , eigBlockVectorR.view (), pp.view ());
1103- raft::linalg::gemm (handle, activeARView , eigBlockVectorR.view (), app.view ());
1107+ raft::linalg::gemm (handle, activeR. view () , eigBlockVectorR.view (), pp.view ());
1108+ raft::linalg::gemm (handle, activeAR. view () , eigBlockVectorR.view (), app.view ());
11041109 if (!restart) {
11051110 raft::linalg::gemm (handle, activePView, eigBlockVectorP.view (), pp.view (), one, one);
11061111 raft::linalg::gemm (handle, activeAPView, eigBlockVectorP.view (), app.view (), one, one);
@@ -1121,12 +1126,31 @@ void lobpcg(
11211126 }
11221127 }
11231128
1124- if (B_opt.has_value ()) { // Using blockVectorR instead of aux
1125- raft::copy (R .data_handle (), BXView.data_handle (), BXView.size (), stream);
1129+ if (B_opt.has_value ()) {
1130+ raft::copy (aux .data_handle (), BXView.data_handle (), BXView.size (), stream);
11261131 } else {
1127- raft::copy (R.data_handle (), X.data_handle (), X.size (), stream);
1132+ raft::copy (aux.data_handle (), X.data_handle (), X.size (), stream);
1133+ }
1134+ raft::linalg::binary_mult_skip_zero (handle, aux.view (), make_const_mdspan (eigLambda.view ()), raft::linalg::Apply::ALONG_ROWS);
1135+
1136+ raft::linalg::subtract (
1137+ handle, raft::make_const_mdspan (AX.view ()), raft::make_const_mdspan (aux.view ()), R.view ());
1138+
1139+ raft::linalg::reduce (
1140+ residual_norms.data_handle (),
1141+ R.data_handle (),
1142+ size_x,
1143+ n,
1144+ value_t (0 ),
1145+ false ,
1146+ true ,
1147+ stream,
1148+ false ,
1149+ raft::sq_op ());
1150+ // TODO check reduce sqrt postop raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());
1151+
1152+ if (verbosityLevel > 0 ) {
1153+ // / TODO add verb
11281154 }
1129- raft::linalg::binary_mult_skip_zero (handle, R.view (), make_const_mdspan (eigLambda.view ()), linalg::Apply::ALONG_ROWS);
1130- raft::linalg::gemm (handle, AX.view (),)
11311155}
11321156}; // namespace raft::sparse::solver::detail
0 commit comments