5050#include < KokkosBlas1_nrm1.hpp>
5151#include < KokkosBlas1_nrm2.hpp>
5252
53+ #include < std_algorithms/Kokkos_ExclusiveScan.hpp>
54+
5355#include < memory>
5456
5557#include " Ifpack2_BlockHelper.hpp"
@@ -1948,8 +1950,7 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
19481950
19491951 using impl_type = BlockHelperDetails::ImplType<MatrixType>;
19501952
1951- using execution_space = typename impl_type::execution_space;
1952- using host_execution_space = typename impl_type::host_execution_space;
1953+ using execution_space = typename impl_type::execution_space;
19531954
19541955 using local_ordinal_type = typename impl_type::local_ordinal_type;
19551956 using global_ordinal_type = typename impl_type::global_ordinal_type;
@@ -1958,10 +1959,11 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
19581959 using size_type_1d_view = typename impl_type::size_type_1d_view;
19591960 using vector_type_3d_view = typename impl_type::vector_type_3d_view;
19601961 using vector_type_4d_view = typename impl_type::vector_type_4d_view;
1961- using internal_vector_type_3d_view = typename impl_type::internal_vector_type_3d_view;
19621962 using crs_matrix_type = typename impl_type::tpetra_crs_matrix_type;
19631963 using block_crs_matrix_type = typename impl_type::tpetra_block_crs_matrix_type;
19641964 using btdm_scalar_type_3d_view = typename impl_type::btdm_scalar_type_3d_view;
1965+ using internal_vector_type_3d_view = typename impl_type::internal_vector_type_3d_view;
1966+ using lo_traits = Tpetra::Details::OrdinalTraits<local_ordinal_type>;
19651967
19661968 constexpr int vector_length = impl_type::vector_length;
19671969 constexpr int internal_vector_length = impl_type::internal_vector_length;
@@ -1975,90 +1977,97 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
19751977 TEUCHOS_ASSERT (hasBlockCrsMatrix || g->getLocalNumRows () != 0 );
19761978 const local_ordinal_type blocksize = hasBlockCrsMatrix ? A->getBlockSize () : A->getLocalNumRows () / g->getLocalNumRows ();
19771979
1978- // mirroring to host
1979- const auto partptr = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), interf.partptr );
1980- const auto lclrow = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), interf.lclrow );
1981- const auto rowidx2part = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), interf.rowidx2part );
1982- const auto part2rowidx0 = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), interf.part2rowidx0 );
1983- const auto packptr = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), interf.packptr );
1980+ const auto partptr = interf.partptr ;
1981+ const auto lclrow = interf.lclrow ;
1982+ const auto rowidx2part = interf.rowidx2part ;
1983+ const auto part2rowidx0 = interf.part2rowidx0 ;
1984+ const auto packptr = interf.packptr ;
19841985
1985- const local_ordinal_type nrows = partptr (partptr.extent (0 ) - 1 );
1986+ // TODO: add nrows as a member of part interface?
1987+ const local_ordinal_type nrows = Kokkos::create_mirror_view_and_copy (
1988+ Kokkos::HostSpace (), Kokkos::subview (partptr, partptr.extent (0 ) - 1 ))();
19861989
1987- Kokkos::View<local_ordinal_type *, host_execution_space > col2row (" col2row" , A->getLocalNumCols ());
1990+ Kokkos::View<local_ordinal_type *, execution_space > col2row (" col2row" , A->getLocalNumCols ());
19881991
19891992 // find column to row map on host
19901993
1991- Kokkos::deep_copy (col2row, Teuchos::OrdinalTraits<local_ordinal_type>::invalid ());
1994+ Kokkos::deep_copy (execution_space (), col2row, Teuchos::OrdinalTraits<local_ordinal_type>::invalid ());
19921995 {
1993- const auto rowmap = g->getRowMap ();
1994- const auto colmap = g->getColMap ();
1995- const auto dommap = g->getDomainMap ();
1996- TEUCHOS_ASSERT (!(rowmap.is_null () || colmap.is_null () || dommap.is_null ()));
1997- rowmap->lazyPushToHost ();
1998- colmap->lazyPushToHost ();
1999- dommap->lazyPushToHost ();
2000-
2001- #if !defined(__CUDA_ARCH__) && !defined(__HIP_DEVICE_COMPILE__) && !defined(__SYCL_DEVICE_ONLY__)
2002- const Kokkos::RangePolicy<host_execution_space> policy (0 , nrows);
1996+ TEUCHOS_ASSERT (!(g->getRowMap ().is_null () || g->getColMap ().is_null () || g->getDomainMap ().is_null ()));
1997+ #if defined(BLOCKTRIDICONTAINER_DEBUG)
1998+ {
1999+ // On host: check that row, col, domain maps are consistent
2000+ auto rowmapHost = g->getRowMap ();
2001+ auto colmapHost = g->getColMap ();
2002+ auto dommapHost = g->getDomainMap ();
2003+ for (local_ordinal_type lr = 0 ; lr < nrows; lr++) {
2004+ const global_ordinal_type gid = rowmapHost->getGlobalElement (lr);
2005+ TEUCHOS_ASSERT (gid != Teuchos::OrdinalTraits<global_ordinal_type>::invalid ());
2006+ if (dommapHost->isNodeGlobalElement (gid)) {
2007+ const local_ordinal_type lc = colmapHost->getLocalElement (gid);
2008+ TEUCHOS_TEST_FOR_EXCEPT_MSG (lc == Teuchos::OrdinalTraits<local_ordinal_type>::invalid (),
2009+ BlockHelperDetails::get_msg_prefix (comm) << " GID " << gid
2010+ << " gives an invalid local column." );
2011+ }
2012+ }
2013+ }
2014+ #endif
2015+ auto rowmap = g->getRowMap ()->getLocalMap ();
2016+ auto colmap = g->getColMap ()->getLocalMap ();
2017+ auto dommap = g->getDomainMap ()->getLocalMap ();
2018+
2019+ const Kokkos::RangePolicy<execution_space> policy (0 , nrows);
20032020 Kokkos::parallel_for (
20042021 " performSymbolicPhase::RangePolicy::col2row" ,
20052022 policy, KOKKOS_LAMBDA (const local_ordinal_type &lr) {
2006- const global_ordinal_type gid = rowmap->getGlobalElement (lr);
2007- TEUCHOS_ASSERT (gid != Teuchos::OrdinalTraits<global_ordinal_type>::invalid ());
2008- if (dommap->isNodeGlobalElement (gid)) {
2009- const local_ordinal_type lc = colmap->getLocalElement (gid);
2010- #if defined(BLOCKTRIDICONTAINER_DEBUG)
2011- TEUCHOS_TEST_FOR_EXCEPT_MSG (lc == Teuchos::OrdinalTraits<local_ordinal_type>::invalid (),
2012- BlockHelperDetails::get_msg_prefix (comm) << " GID " << gid
2013- << " gives an invalid local column." );
2014- #endif
2015- col2row (lc) = lr;
2023+ const global_ordinal_type gid = rowmap.getGlobalElement (lr);
2024+ if (dommap.getLocalElement (gid) != lo_traits::invalid ()) {
2025+ const local_ordinal_type lc = colmap.getLocalElement (gid);
2026+ col2row (lc) = lr;
20162027 }
20172028 });
2018- #endif
20192029 }
20202030
20212031 // construct the D and R graphs in A = D + R.
20222032 {
2023- const auto local_graph = g->getLocalGraphHost ();
2033+ const auto local_graph = g->getLocalGraphDevice ();
20242034 const auto local_graph_rowptr = local_graph.row_map ;
20252035 TEUCHOS_ASSERT (local_graph_rowptr.size () == static_cast <size_t >(nrows + 1 ));
20262036 const auto local_graph_colidx = local_graph.entries ;
20272037
20282038 // assume no overlap.
20292039
2030- Kokkos::View<local_ordinal_type *, host_execution_space > lclrow2idx (" lclrow2idx" , nrows);
2040+ Kokkos::View<local_ordinal_type *, execution_space > lclrow2idx (" lclrow2idx" , nrows);
20312041 {
2032- const Kokkos::RangePolicy<host_execution_space > policy (0 , nrows);
2042+ const Kokkos::RangePolicy<execution_space > policy (0 , nrows);
20332043 Kokkos::parallel_for (
20342044 " performSymbolicPhase::RangePolicy::lclrow2idx" ,
20352045 policy, KOKKOS_LAMBDA (const local_ordinal_type &i) {
2036- lclrow2idx[ lclrow (i)] = i;
2046+ lclrow2idx ( lclrow (i)) = i;
20372047 });
20382048 }
20392049
20402050 // count (block) nnzs in D and R.
2041- typedef BlockHelperDetails::SumReducer<size_type, 3 , host_execution_space> sum_reducer_type;
2042- typename sum_reducer_type::value_type sum_reducer_value;
2051+ size_type D_nnz, R_nnz_owned, R_nnz_remote;
20432052 {
2044- const Kokkos::RangePolicy<host_execution_space > policy (0 , nrows);
2053+ const Kokkos::RangePolicy<execution_space > policy (0 , nrows);
20452054 Kokkos::parallel_reduce
20462055 // profiling interface does not work
20472056 ( // "performSymbolicPhase::RangePolicy::count_nnz",
2048- policy, KOKKOS_LAMBDA (const local_ordinal_type &lr, typename sum_reducer_type::value_type &update ) {
2057+ policy, KOKKOS_LAMBDA (const local_ordinal_type &lr, size_type &update_D_nnz, size_type &update_R_nnz_owned, size_type &update_R_nnz_remote ) {
20492058 // LID -> index.
2050- const local_ordinal_type ri0 = lclrow2idx[lr] ;
2059+ const local_ordinal_type ri0 = lclrow2idx (lr) ;
20512060 const local_ordinal_type pi0 = rowidx2part (ri0);
20522061 for (size_type j = local_graph_rowptr (lr); j < local_graph_rowptr (lr + 1 ); ++j) {
20532062 const local_ordinal_type lc = local_graph_colidx (j);
2054- const local_ordinal_type lc2r = col2row[lc] ;
2063+ const local_ordinal_type lc2r = col2row (lc) ;
20552064 bool incr_R = false ;
20562065 do { // breakable
20572066 if (lc2r == (local_ordinal_type)-1 ) {
20582067 incr_R = true ;
20592068 break ;
20602069 }
2061- const local_ordinal_type ri = lclrow2idx[ lc2r] ;
2070+ const local_ordinal_type ri = lclrow2idx ( lc2r) ;
20622071 const local_ordinal_type pi = rowidx2part (ri);
20632072 if (pi != pi0) {
20642073 incr_R = true ;
@@ -2068,23 +2077,20 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
20682077 // LID space, tridiag LIDs in a row are not necessarily related by
20692078 // {-1, 0, 1}.
20702079 if (ri0 + 1 >= ri && ri0 <= ri + 1 )
2071- ++update. v [ 0 ]; // D_nnz
2080+ ++update_D_nnz;
20722081 else
20732082 incr_R = true ;
20742083 } while (0 );
20752084 if (incr_R) {
20762085 if (lc < nrows)
2077- ++update. v [ 1 ]; // R_nnz_owned
2086+ ++update_R_nnz_owned;
20782087 else
2079- ++update. v [ 2 ]; // R_nnz_remote
2088+ ++update_R_nnz_remote;
20802089 }
20812090 }
20822091 },
2083- sum_reducer_type (sum_reducer_value) );
2092+ D_nnz, R_nnz_owned, R_nnz_remote );
20842093 }
2085- size_type D_nnz = sum_reducer_value.v [0 ];
2086- size_type R_nnz_owned = sum_reducer_value.v [1 ];
2087- size_type R_nnz_remote = sum_reducer_value.v [2 ];
20882094
20892095 if (!overlap_communication_and_computation) {
20902096 R_nnz_owned += R_nnz_remote;
@@ -2093,10 +2099,10 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
20932099
20942100 // construct the D_00 graph.
20952101 {
2096- const auto flat_td_ptr = Kokkos::create_mirror_view_and_copy ( Kokkos::HostSpace (), btdm.flat_td_ptr ) ;
2102+ const auto flat_td_ptr = btdm.flat_td_ptr ;
20972103
20982104 btdm.A_colindsub = local_ordinal_type_1d_view (" btdm.A_colindsub" , D_nnz);
2099- const auto D_A_colindsub = Kokkos::create_mirror_view ( btdm.A_colindsub ) ;
2105+ const auto D_A_colindsub = btdm.A_colindsub ;
21002106
21012107#if defined(BLOCKTRIDICONTAINER_DEBUG)
21022108 Kokkos::deep_copy (D_A_colindsub, Teuchos::OrdinalTraits<local_ordinal_type>::invalid ());
@@ -2105,9 +2111,9 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
21052111 const local_ordinal_type nparts = partptr.extent (0 ) - 1 ;
21062112
21072113 {
2108- const Kokkos::RangePolicy<host_execution_space > policy (0 , nparts);
2114+ const Kokkos::RangePolicy<execution_space > policy (0 , nparts);
21092115 Kokkos::parallel_for (
2110- " performSymbolicPhase::RangePolicy<host_execution_space >::D_graph" ,
2116+ " performSymbolicPhase::RangePolicy<execution_space >::D_graph" ,
21112117 policy, KOKKOS_LAMBDA (const local_ordinal_type &pi0) {
21122118 const local_ordinal_type part_ri0 = part2rowidx0 (pi0);
21132119 local_ordinal_type offset = 0 ;
@@ -2131,10 +2137,12 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
21312137 });
21322138 }
21332139#if defined(BLOCKTRIDICONTAINER_DEBUG)
2134- for (size_t i = 0 ; i < D_A_colindsub.extent (0 ); ++i)
2135- TEUCHOS_ASSERT (D_A_colindsub (i) != Teuchos::OrdinalTraits<local_ordinal_type>::invalid ());
2140+ {
2141+ auto D_A_colindsub_host = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), D_A_colindsub);
2142+ for (size_t i = 0 ; i < D_A_colindsub_host.extent (0 ); ++i)
2143+ TEUCHOS_ASSERT (D_A_colindsub_host (i) != Teuchos::OrdinalTraits<local_ordinal_type>::invalid ());
2144+ }
21362145#endif
2137- Kokkos::deep_copy (btdm.A_colindsub , D_A_colindsub);
21382146
21392147 // Allocate values.
21402148 {
@@ -2157,19 +2165,19 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
21572165 amd.rowptr = size_type_1d_view (" amd.rowptr" , nrows + 1 );
21582166 amd.A_colindsub = local_ordinal_type_1d_view (do_not_initialize_tag (" amd.A_colindsub" ), R_nnz_owned);
21592167
2160- const auto R_rowptr = Kokkos::create_mirror_view ( amd.rowptr ) ;
2161- const auto R_A_colindsub = Kokkos::create_mirror_view ( amd.A_colindsub ) ;
2168+ const auto R_rowptr = amd.rowptr ;
2169+ const auto R_A_colindsub = amd.A_colindsub ;
21622170
21632171 amd.rowptr_remote = size_type_1d_view (" amd.rowptr_remote" , overlap_communication_and_computation ? nrows + 1 : 0 );
21642172 amd.A_colindsub_remote = local_ordinal_type_1d_view (do_not_initialize_tag (" amd.A_colindsub_remote" ), R_nnz_remote);
21652173
2166- const auto R_rowptr_remote = Kokkos::create_mirror_view ( amd.rowptr_remote ) ;
2167- const auto R_A_colindsub_remote = Kokkos::create_mirror_view ( amd.A_colindsub_remote ) ;
2174+ const auto R_rowptr_remote = amd.rowptr_remote ;
2175+ const auto R_A_colindsub_remote = amd.A_colindsub_remote ;
21682176
21692177 {
2170- const Kokkos::RangePolicy<host_execution_space > policy (0 , nrows);
2178+ const Kokkos::RangePolicy<execution_space > policy (0 , nrows);
21712179 Kokkos::parallel_for (
2172- " performSymbolicPhase::RangePolicy<host_execution_space >::R_graph_count" ,
2180+ " performSymbolicPhase::RangePolicy<execution_space >::R_graph_count" ,
21732181 policy, KOKKOS_LAMBDA (const local_ordinal_type &lr) {
21742182 const local_ordinal_type ri0 = lclrow2idx[lr];
21752183 const local_ordinal_type pi0 = rowidx2part (ri0);
@@ -2193,59 +2201,48 @@ void performSymbolicPhase(const Teuchos::RCP<const typename BlockHelperDetails::
21932201 }
21942202 });
21952203 }
2196-
2197- // exclusive scan
2198- typedef BlockHelperDetails::ArrayValueType<size_type, 2 > update_type ;
2204+ // Prefix-sums to finish computing R_rowptr and R_rowptr_remote
2205+ Kokkos::Experimental::exclusive_scan ( execution_space (), R_rowptr, R_rowptr, size_type ( 0 ));
2206+ Kokkos::Experimental::exclusive_scan ( execution_space (), R_rowptr_remote, R_rowptr_remote, size_type ( 0 )) ;
21992207 {
2200- Kokkos::RangePolicy<host_execution_space> policy (0 , nrows + 1 );
2201- Kokkos::parallel_scan (
2202- " performSymbolicPhase::RangePolicy<host_execution_space>::R_graph_fill" ,
2203- policy, KOKKOS_LAMBDA (const local_ordinal_type &lr, update_type &update, const bool &final ) {
2204- update_type val;
2205- val.v [0 ] = R_rowptr (lr);
2206- if (overlap_communication_and_computation)
2207- val.v [1 ] = R_rowptr_remote (lr);
2208-
2209- if (final ) {
2210- R_rowptr (lr) = update.v [0 ];
2211- if (overlap_communication_and_computation)
2212- R_rowptr_remote (lr) = update.v [1 ];
2213-
2214- if (lr < nrows) {
2215- const local_ordinal_type ri0 = lclrow2idx[lr];
2216- const local_ordinal_type pi0 = rowidx2part (ri0);
2217-
2218- size_type cnt_rowptr = R_rowptr (lr);
2219- size_type cnt_rowptr_remote = overlap_communication_and_computation ? R_rowptr_remote (lr) : 0 ; // when not overlap_communication_and_computation, this value is garbage
2220-
2221- const size_type j0 = local_graph_rowptr (lr);
2222- for (size_type j = j0; j < local_graph_rowptr (lr + 1 ); ++j) {
2223- const local_ordinal_type lc = local_graph_colidx (j);
2224- const local_ordinal_type lc2r = col2row[lc];
2225- if (lc2r != (local_ordinal_type)-1 ) {
2226- const local_ordinal_type ri = lclrow2idx[lc2r];
2227- const local_ordinal_type pi = rowidx2part (ri);
2228- if (pi == pi0 && ri + 1 >= ri0 && ri <= ri0 + 1 )
2229- continue ;
2230- }
2231- const local_ordinal_type row_entry = j - j0;
2232- if (!overlap_communication_and_computation || lc < nrows)
2233- R_A_colindsub (cnt_rowptr++) = row_entry;
2234- else
2235- R_A_colindsub_remote (cnt_rowptr_remote++) = row_entry;
2236- }
2208+ // Fill R graph entries (R_A_colindsub and R_A_colindsub_remote)
2209+ Kokkos::RangePolicy<execution_space> policy (0 , nrows);
2210+ Kokkos::parallel_for (
2211+ " performSymbolicPhase::RangePolicy<execution_space>::R_graph_fill" ,
2212+ policy, KOKKOS_LAMBDA (const local_ordinal_type &lr) {
2213+ const local_ordinal_type ri0 = lclrow2idx[lr];
2214+ const local_ordinal_type pi0 = rowidx2part (ri0);
2215+
2216+ size_type cnt_rowptr = R_rowptr (lr);
2217+ size_type cnt_rowptr_remote = overlap_communication_and_computation ? R_rowptr_remote (lr) : 0 ; // when not overlap_communication_and_computation, this value is garbage
2218+
2219+ const size_type j0 = local_graph_rowptr (lr);
2220+ for (size_type j = j0; j < local_graph_rowptr (lr + 1 ); ++j) {
2221+ const local_ordinal_type lc = local_graph_colidx (j);
2222+ const local_ordinal_type lc2r = col2row[lc];
2223+ if (lc2r != (local_ordinal_type)-1 ) {
2224+ const local_ordinal_type ri = lclrow2idx[lc2r];
2225+ const local_ordinal_type pi = rowidx2part (ri);
2226+ if (pi == pi0 && ri + 1 >= ri0 && ri <= ri0 + 1 )
2227+ continue ;
22372228 }
2229+ const local_ordinal_type row_entry = j - j0;
2230+ if (!overlap_communication_and_computation || lc < nrows)
2231+ R_A_colindsub (cnt_rowptr++) = row_entry;
2232+ else
2233+ R_A_colindsub_remote (cnt_rowptr_remote++) = row_entry;
22382234 }
2239- update += val;
22402235 });
22412236 }
2242- TEUCHOS_ASSERT (R_rowptr (nrows) == R_nnz_owned);
2243- Kokkos::deep_copy (amd.rowptr , R_rowptr);
2244- Kokkos::deep_copy (amd.A_colindsub , R_A_colindsub);
2245- if (overlap_communication_and_computation) {
2246- TEUCHOS_ASSERT (R_rowptr_remote (nrows) == R_nnz_remote);
2247- Kokkos::deep_copy (amd.rowptr_remote , R_rowptr_remote);
2248- Kokkos::deep_copy (amd.A_colindsub_remote , R_A_colindsub_remote);
2237+ {
2238+ // Check that the last elements of R_rowptr (aka amd.rowptr)
2239+ // and R_rowptr_remote (aka amd.rowptr_remote) match the expected entry counts
2240+ auto r_rowptr_end = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), Kokkos::subview (R_rowptr, nrows));
2241+ TEUCHOS_ASSERT (r_rowptr_end () == R_nnz_owned);
2242+ if (overlap_communication_and_computation) {
2243+ auto r_rowptr_remote_end = Kokkos::create_mirror_view_and_copy (Kokkos::HostSpace (), Kokkos::subview (R_rowptr_remote, nrows));
2244+ TEUCHOS_ASSERT (r_rowptr_remote_end () == R_nnz_remote);
2245+ }
22492246 }
22502247
22512248 // Allocate or view values.
0 commit comments