diff --git a/src/TiledArray/conversions/foreach.h b/src/TiledArray/conversions/foreach.h index 20f2d36ec3..4f02387902 100644 --- a/src/TiledArray/conversions/foreach.h +++ b/src/TiledArray/conversions/foreach.h @@ -480,7 +480,8 @@ inline std::enable_if_t, DistArray> foreach ( /// function will fence before AND after the data is modified template ::type>::value>::type> + typename std::decay::type>::value>::type, + typename = typename std::enable_if::value>::type> inline std::enable_if_t, void> foreach_inplace( DistArray& arg, Op&& op, bool fence = true) { // The tile data is being modified in place, which means we may need to @@ -495,6 +496,68 @@ inline std::enable_if_t, void> foreach_inplace( if (fence) arg.world().gop.fence(); } +/// Modify each element of an Array object + +/// This function modifies the elements of a \c DistArray object with a const reference to the +/// index of the current element. This allows the user to modify specific elements of the array +/// based on their indices. Users must provide a function/functor that modifies each element. The provided function +/// should take a reference to a \c Tile object and a reference to a \c std::vector +/// representing the indices of the current element within the tile. For example, +/// to copy the upper triangular elements of a nxnxn array to a c++ vector of size n^3: +/// \code +/// std::vector vec(n*n*n); +/// forall(array, [&vec] (auto& tile, const auto& index) { +/// size_t i = index[0], j = index[1], k = index[2]; +/// if (i <= j && j <= k) { +/// vec[i*n*n+j*n+k] = tile[index]; +/// } else { +/// vec[i*n*n+j*n+k] = 0.0; +/// } +/// }); +/// \endcode +/// Similarly, to set each upper triangular element of a nxnxn array to the square root of values in a c++ vector of size n^3: +/// \code +/// vector vec(n*n*n); +/// std::generate(v.begin(), v.end(), std::rand); +/// forall(array, [&vec] (Tile& tile, index_type& index) { +/// size_t i = index[0], j = index[1], k = index[2]; +/// if (i <= j && j <= k) { +/// tile[index] = std::sqrt(vec[i*n*n+j*n+k]); +/// } else { +/// tile[index] = 0.0; +/// } +/// }); +/// \endcode +/// The expected signature of the element operation is: +/// \code +/// void op(Tile& tile, Range::index_type& index); +/// \endcode +/// \tparam Tile The tile type of \c arg +/// \tparam Policy The policy type of \c arg +/// \tparam Op Mutating element operation +/// \param arg The argument array to be modified +/// \param op The mutating element function +/// \param fence If \c true, this function will fence before and after the data is modified +template ::type>::value>::type, + typename = typename std::enable_if::value>::type> +inline void foreach_inplace( + DistArray& arg, Op&& op, bool fence = true) { + + // wrap Op into a shallow-copy copyable handle + auto op_shared_handle = make_op_shared_handle(std::forward(op)); + + // Use foreach_inplace to iterate over tiles and modify elements + foreach_inplace( + arg, + [op = std::move(op_shared_handle)](Tile& tile) mutable { + for (const Range::index_type& index : tile.range()) + op(tile, index); + }, fence); // Fence before and after the data is modified +} + /// Apply a function to each tile of a sparse Array /// This function uses an \c Array object to generate a new \c Array where the @@ -587,7 +650,8 @@ inline std::enable_if_t, DistArray> foreach ( /// function will fence before AND after the data is modified template ::type>::value>::type> + typename std::decay::type>::value>::type, + typename = typename std::enable_if::value>::type> inline std::enable_if_t, void> foreach_inplace( DistArray& arg, Op&& op, bool fence = true) { // The tile data is being modified in place, which means we may need to @@ -629,7 +693,8 @@ inline std:: } /// This function takes two input tiles and put result into the left tile -template +template ::value>::type> inline std::enable_if_t, void> foreach_inplace( DistArray& left, const DistArray& right, Op&& op, bool fence = true) { @@ -675,7 +740,8 @@ inline std:: } /// This function takes two input tiles and put result into the left tile -template +template ::value>::type> inline std::enable_if_t, void> foreach_inplace( DistArray& left, const DistArray& right, Op&& op, diff --git a/tests/foreach.cpp b/tests/foreach.cpp index 106a855a0a..d4cb1adda6 100644 --- a/tests/foreach.cpp +++ b/tests/foreach.cpp @@ -119,6 +119,26 @@ BOOST_AUTO_TEST_CASE(foreach_unary) { } } +BOOST_AUTO_TEST_CASE(foreach_w_idx) { + + TArrayI result = a.clone(); + foreach_inplace(result, [](TensorI& tile, const Range::index_type &coord_idx) { + long fac = (coord_idx[0] < coord_idx[1]) ? coord_idx[0] : coord_idx[1]; + tile[coord_idx] = fac * tile[coord_idx]; + }, true); + + for (auto index : *result.pmap()) { + TensorI tile0 = a.find(index).get(); + TensorI tile = result.find(index).get(); + const Range &range = tile0.range(); + for (std::size_t i = 0; i < tile.size(); ++i) { + const Range::index_type &coord_idx = range.idx(i); + long fac = coord_idx[0] < coord_idx[1] ? coord_idx[0] : coord_idx[1]; + BOOST_CHECK_EQUAL(tile[i], fac * tile0[i]); + } + } +} + BOOST_AUTO_TEST_CASE(foreach_unary_sparse) { TSpArrayI result = foreach (c, [](TensorI& result, const TensorI& arg) -> float {