Skip to content

Commit 3754b83

Browse files
authored
Merge pull request #14700 from crtrott/stokhos-stuff
Refactor Stokhos to support Kokkos 5.0
2 parents 51aa768 + 7bc915c commit 3754b83

15 files changed

+359
-0
lines changed

packages/sacado/test/UnitTests/Fad_KokkosTests.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,6 +1472,9 @@ TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL(
14721472

14731473
// Check dimensions are correct
14741474
TEUCHOS_TEST_EQUALITY(Kokkos::dimension_scalar(v2), fad_size+1, out, success);
1475+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
1476+
TEUCHOS_TEST_EQUALITY(v2.stride(0), v1.stride(0), out, success);
1477+
#endif
14751478

14761479
// Check values
14771480
FadType f =
@@ -1509,6 +1512,9 @@ TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL(
15091512
TEUCHOS_TEST_EQUALITY(v2.extent(0), num_rows, out, success);
15101513
TEUCHOS_TEST_EQUALITY(Kokkos::dimension_scalar(v2), fad_size+1, out, success);
15111514
TEUCHOS_TEST_EQUALITY(v2.stride(0), v1.stride(0), out, success);
1515+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
1516+
TEUCHOS_TEST_EQUALITY(v2.stride(1), v1.stride(1), out, success);
1517+
#endif
15121518

15131519
// Check values
15141520
for (size_type i=0; i<num_rows; ++i) {
@@ -1550,6 +1556,9 @@ TEUCHOS_UNIT_TEST_TEMPLATE_3_DECL(
15501556
TEUCHOS_TEST_EQUALITY(Kokkos::dimension_scalar(v2), fad_size+1, out, success);
15511557
TEUCHOS_TEST_EQUALITY(v2.stride(0), v1.stride(0), out, success);
15521558
TEUCHOS_TEST_EQUALITY(v2.stride(1), v1.stride(1), out, success);
1559+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
1560+
TEUCHOS_TEST_EQUALITY(v2.stride(2), v1.stride(2), out, success);
1561+
#endif
15531562

15541563
// Check values
15551564
for (size_type i=0; i<num_rows; ++i) {

packages/stokhos/src/sacado/kokkos/vector/KokkosExp_View_MP_Vector_Contiguous.hpp

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,201 @@
2828
#include "Kokkos_View_Utils.hpp"
2929
#include "Kokkos_View_MP_Vector_Utils.hpp"
3030

31+
#ifndef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
32+
#include "Sacado.hpp"
33+
#include "Kokkos_DualView.hpp"
3134
//----------------------------------------------------------------------------
3235

36+
namespace Stokhos {
37+
38+
template<class T>
39+
struct is_mp_vector {
40+
static constexpr bool value = false;
41+
};
42+
43+
template<class I, class T, int N, class Exec>
44+
struct is_mp_vector<Sacado::MP::Vector<Stokhos::StaticFixedStorage<I, T, N, Exec>>> {
45+
static constexpr bool value = true;
46+
};
47+
48+
template<class I, class T, int N, class Exec>
49+
struct is_mp_vector<const Sacado::MP::Vector<Stokhos::StaticFixedStorage<I, T, N, Exec>>> {
50+
static constexpr bool value = true;
51+
};
52+
53+
template<class T>
54+
static constexpr bool is_mp_vector_v = is_mp_vector<T>::value;
55+
56+
}
57+
58+
namespace Kokkos {
59+
60+
template<class DataType, class ... Args>
61+
struct is_view_mp_vector<Kokkos::View<DataType, Args...>> {
62+
static constexpr bool value = Stokhos::is_mp_vector_v<typename Kokkos::View<DataType, Args...>::value_type>;
63+
};
64+
template <typename T, typename ... P>
65+
requires(is_view_mp_vector< View<T,P...> >::value)
66+
KOKKOS_INLINE_FUNCTION
67+
constexpr unsigned dimension_scalar(const View<T,P...>& view) {
68+
return typename View<T,P...>::value_type().size();
69+
}
70+
71+
}
72+
73+
74+
namespace Stokhos {
75+
76+
template<class DataType, class ... Args>
77+
requires(!is_mp_vector_v<typename Kokkos::View<DataType, Args...>::value_type>)
78+
KOKKOS_FUNCTION
79+
auto reinterpret_as_unmanaged_scalar_view(const Kokkos::View<DataType, Args...>& view) {
80+
return view;
81+
}
82+
83+
template<class DataType, class ... Args>
84+
requires(is_mp_vector_v<typename Kokkos::View<DataType, Args...>::value_type>)
85+
KOKKOS_FUNCTION
86+
auto reinterpret_as_unmanaged_scalar_view(const Kokkos::View<DataType, Args...>& view) {
87+
using view_t = Kokkos::View<DataType, Args...>;
88+
using scalar_t = typename view_t::value_type::value_type;
89+
constexpr bool is_layout_right = std::is_same_v<typename view_t::layout_type, Kokkos::LayoutRight> ||
90+
std::is_same_v<typename view_t::layout_type, Kokkos::Experimental::layout_right_padded<Kokkos::dynamic_extent>>;
91+
if constexpr (view_t::rank() == 0) {
92+
return Kokkos::View<scalar_t*, Args...>(reinterpret_cast<scalar_t*>(view.data()), Kokkos::dimension_scalar(view));
93+
} else
94+
if constexpr (view_t::rank() == 1) {
95+
if constexpr (is_layout_right) {
96+
return Kokkos::View<scalar_t**, Args...>(reinterpret_cast<scalar_t*>(view.data()), view.extent(0), Kokkos::dimension_scalar(view));
97+
} else {
98+
return Kokkos::View<scalar_t**, Args...>(reinterpret_cast<scalar_t*>(view.data()), Kokkos::dimension_scalar(view), view.extent(0));
99+
}
100+
} else
101+
if constexpr (view_t::rank() == 2) {
102+
if constexpr (is_layout_right) {
103+
return Kokkos::View<scalar_t***, Args...>(reinterpret_cast<scalar_t*>(view.data()), view.extent(0), view.extent(1), Kokkos::dimension_scalar(view));
104+
} else {
105+
return Kokkos::View<scalar_t***, Args...>(reinterpret_cast<scalar_t*>(view.data()), Kokkos::dimension_scalar(view), view.extent(0), view.extent(1));
106+
}
107+
} else
108+
if constexpr (view_t::rank() == 3) {
109+
if constexpr (is_layout_right) {
110+
return Kokkos::View<scalar_t****, Args...>(reinterpret_cast<scalar_t*>(view.data()), view.extent(0), view.extent(1), view.extent(2), Kokkos::dimension_scalar(view));
111+
} else {
112+
return Kokkos::View<scalar_t****, Args...>(reinterpret_cast<scalar_t*>(view.data()), Kokkos::dimension_scalar(view), view.extent(0), view.extent(1), view.extent(2));
113+
}
114+
} else
115+
if constexpr (view_t::rank() == 4) {
116+
if constexpr (is_layout_right) {
117+
return Kokkos::View<scalar_t****, Args...>(reinterpret_cast<scalar_t*>(view.data()), view.extent(0), view.extent(1), view.extent(2), view.extent(3), Kokkos::dimension_scalar(view));
118+
} else {
119+
return Kokkos::View<scalar_t****, Args...>(reinterpret_cast<scalar_t*>(view.data()), Kokkos::dimension_scalar(view), view.extent(0), view.extent(1), view.extent(2), view.extent(3));
120+
}
121+
}
122+
}
123+
124+
template<class DataType, class ... Args>
125+
requires(!is_mp_vector_v<typename Kokkos::View<DataType, Args...>::value_type>)
126+
KOKKOS_FUNCTION
127+
auto reinterpret_as_unmanaged_scalar_flat_view(const Kokkos::View<DataType, Args...>& view) {
128+
return view;
129+
}
130+
131+
template<class DataType, class ... Args>
132+
requires(is_mp_vector_v<typename Kokkos::View<DataType, Args...>::value_type>)
133+
KOKKOS_FUNCTION
134+
auto reinterpret_as_unmanaged_scalar_flat_view(const Kokkos::View<DataType, Args...>& view) {
135+
using view_t = Kokkos::View<DataType, Args...>;
136+
using value_type = typename view_t::value_type::value_type;
137+
using scalar_t = std::conditional_t<std::is_const_v<typename view_t::value_type>, const value_type, value_type>;
138+
return Kokkos::View<scalar_t*, Args...>(reinterpret_cast<scalar_t*>(view.data()), view.mapping().required_span_size() * Kokkos::dimension_scalar(view));
139+
}
140+
141+
template<class T>
142+
using scalar_view_t = decltype(reinterpret_as_unmanaged_scalar_view(std::declval<T>()));
143+
144+
template<class T>
145+
using scalar_flat_view_t = decltype(reinterpret_as_unmanaged_scalar_flat_view(std::declval<T>()));
146+
147+
namespace {
148+
template<class DataType, class ... Args>
149+
struct DualViewModifiedFlagsAccessor: public Kokkos::DualView<DataType, Args...> {
150+
using base_t = Kokkos::DualView<DataType, Args...>;
151+
using base_t::base_t;
152+
auto get_modified_flags() const { return base_t::modified_flags; }
153+
DualViewModifiedFlagsAccessor(
154+
typename base_t::t_modified_flags mod_flags,
155+
typename base_t::t_dev dev,
156+
typename base_t::t_host host):base_t(dev, host) {
157+
base_t::modified_flags = mod_flags;
158+
}
159+
DualViewModifiedFlagsAccessor(base_t dv):base_t(dv) {}
160+
};
161+
}
162+
163+
// DualView reinterpretation for creating flat Tpetra MV
164+
template<class DataType, class ... Args>
165+
requires(!is_mp_vector_v<typename Kokkos::DualView<DataType, Args...>::value_type>)
166+
KOKKOS_FUNCTION
167+
auto reinterpret_as_unmanaged_scalar_dual_view_of_same_rank(const Kokkos::DualView<DataType, Args...>& view) {
168+
return view;
169+
}
170+
171+
template<class DataType, class ... Args>
172+
requires(is_mp_vector_v<typename Kokkos::View<DataType, Args...>::value_type>)
173+
KOKKOS_FUNCTION
174+
auto reinterpret_as_unmanaged_scalar_dual_view_of_same_rank(const Kokkos::DualView<DataType, Args...>& view) {
175+
static_assert(Kokkos::DualView<DataType, Args...>::t_dev::rank() == 2);
176+
177+
using view_t = Kokkos::DualView<DataType, Args...>;
178+
using scalar_t = typename view_t::value_type::value_type;
179+
using view_scalar_t = Kokkos::DualView<scalar_t**, Args...>;
180+
181+
auto view_dev = view.view_device();
182+
auto view_host = view.view_host();
183+
DualViewModifiedFlagsAccessor<DataType, Args...> view_access(view);
184+
printf("Sacado::dimension_scalar: %i %i\n", (int)Sacado::dimension_scalar(view), (int)typename Kokkos::View<DataType, Args...>::value_type().size());
185+
int dim_scalar = typename Kokkos::View<DataType, Args...>::value_type().size();
186+
constexpr bool is_layout_left = std::is_same_v<typename Kokkos::DualView<DataType, Args...>::t_dev::layout_type, Kokkos::LayoutLeft> ||
187+
std::is_same_v<typename Kokkos::DualView<DataType, Args...>::t_dev::layout_type, Kokkos::Experimental::layout_left_padded<Kokkos::dynamic_extent>>;
188+
if constexpr (is_layout_left) {
189+
DualViewModifiedFlagsAccessor<scalar_t**, Args...> view_scalar(
190+
view_access.get_modified_flags(),
191+
typename view_scalar_t::t_dev(reinterpret_cast<scalar_t*>(view_dev.data()),
192+
view_dev.extent(0) * dim_scalar, view_dev.extent(1)),
193+
typename view_scalar_t::t_host(reinterpret_cast<scalar_t*>(view_host.data()),
194+
view_dev.extent(0) * dim_scalar, view_dev.extent(1))
195+
);
196+
197+
return view_scalar_t(view_scalar);
198+
} else {
199+
DualViewModifiedFlagsAccessor<scalar_t**, Args...> view_scalar(
200+
view_access.get_modified_flags(),
201+
typename view_scalar_t::t_dev(reinterpret_cast<scalar_t*>(view_dev.data()),
202+
view_dev.extent(0), dim_scalar * view_dev.extent(1)),
203+
typename view_scalar_t::t_host(reinterpret_cast<scalar_t*>(view_host.data()),
204+
view_dev.extent(0), dim_scalar * view_dev.extent(1))
205+
);
206+
207+
return view_scalar_t(view_scalar);
208+
}
209+
}
210+
211+
}
212+
213+
namespace Kokkos {
214+
215+
template <typename D, typename ... P>
216+
struct FlatArrayType< View<D,P...>,
217+
typename std::enable_if< is_view_mp_vector< View<D,P...> >::value >::type > {
218+
typedef View<D,P...> view_type;
219+
typedef typename view_type::value_type::value_type flat_value_type;
220+
using type = Stokhos::scalar_flat_view_t<view_type>;
221+
};
222+
223+
}
224+
225+
#else // KOKKOS_ENABLE_IMPL_VIEW_LEGACY
33226
namespace Kokkos {
34227
namespace Experimental {
35228
namespace Impl {
@@ -1847,6 +2040,7 @@ class ViewMapping< DstTraits , SrcTraits ,
18472040

18482041
} // namespace Impl
18492042
} // namespace Kokkos
2043+
#endif
18502044

18512045
//----------------------------------------------------------------------------
18522046
//----------------------------------------------------------------------------

packages/stokhos/src/sacado/kokkos/vector/amesos2/Amesos2_Solver_MP_Vector.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "Amesos2_Factory.hpp"
1515
#include "Stokhos_Sacado_Kokkos_MP_Vector.hpp"
1616
#include "Stokhos_Tpetra_Utilities_MP_Vector.hpp"
17+
#include "KokkosExp_View_MP_Vector_Contiguous.hpp"
1718

1819
namespace Amesos2 {
1920

packages/stokhos/src/sacado/kokkos/vector/ifpack2/Stokhos_Ifpack2_MP_Vector.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,13 @@ struct LocalReciprocalThreshold<
3838
"LocalReciprocalThreshold not implemented for non-constant minVal");
3939
}
4040

41+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
4142
typedef typename Kokkos::FlatArrayType<XV>::type Flat_XV;
4243
Flat_XV flat_X = X;
44+
#else
45+
using Flat_XV = Stokhos::scalar_flat_view_t<XV>;
46+
Flat_XV flat_X = Stokhos::reinterpret_as_unmanaged_scalar_flat_view(X);
47+
#endif
4348
LocalReciprocalThreshold< Flat_XV, SizeType >::compute( flat_X,
4449
minVal.coeff(0) );
4550
}

packages/stokhos/src/sacado/kokkos/vector/linalg/Kokkos_CrsMatrix_MP_Vector.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,11 @@ spmv(
542542
typedef Kokkos::View< OutputType, OutputP... > OutputVectorType;
543543
typedef Kokkos::View< InputType, InputP... > InputVectorType;
544544
using input_vector_type = const_type_t<InputVectorType>;
545+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
545546
typedef typename InputVectorType::array_type::non_const_value_type value_type;
547+
#else
548+
typedef std::remove_const_t<typename InputVectorType::element_type::value_type> value_type;
549+
#endif
546550

547551
#if KOKKOSKERNELS_VERSION >= 40199
548552
if(space != ExecutionSpace()) {
@@ -675,7 +679,11 @@ spmv(
675679
typedef Kokkos::View< OutputType, OutputP... > OutputVectorType;
676680
typedef Kokkos::View< InputType, InputP... > InputVectorType;
677681
using input_vector_type = const_type_t<InputVectorType>;
682+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
678683
typedef typename InputVectorType::array_type::non_const_value_type value_type;
684+
#else
685+
typedef std::remove_const_t<typename InputVectorType::element_type::value_type> value_type;
686+
#endif
679687

680688
if (!Sacado::is_constant(a) || !Sacado::is_constant(b)) {
681689
Kokkos::Impl::raise_error(

packages/stokhos/src/sacado/kokkos/vector/mpicomm/Kokkos_TeuchosCommAdapters_MP_Vector.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#include "Sacado_MP_Vector.hpp"
1717
#include "Kokkos_View_MP_Vector.hpp"
1818
#include "Kokkos_TeuchosCommAdapters.hpp"
19+
#ifndef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
20+
#include "Sacado_Fad_Kokkos_View_Support.hpp"
21+
#endif
1922

2023
//----------------------------------------------------------------------------
2124
// Overloads of Teuchos Comm View functions for Sacado::MP::Vector scalar type
@@ -35,7 +38,11 @@ send (const Kokkos::View<D,P...>& sendBuffer,
3538
typedef Kokkos::View<D,P...> view_type;
3639
typedef typename Kokkos::FlatArrayType<view_type>::type flat_array_type;
3740

41+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
3842
flat_array_type array = sendBuffer;
43+
#else
44+
flat_array_type array = Stokhos::reinterpret_as_unmanaged_scalar_flat_view(sendBuffer);
45+
#endif
3946
Ordinal array_count = count * Kokkos::dimension_scalar(sendBuffer);
4047
send(array, array_count, destRank, tag, comm);
4148
}
@@ -52,7 +59,11 @@ ssend (const Kokkos::View<D,P...>& sendBuffer,
5259
typedef Kokkos::View<D,P...> view_type;
5360
typedef typename Kokkos::FlatArrayType<view_type>::type flat_array_type;
5461

62+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
5563
flat_array_type array = sendBuffer;
64+
#else
65+
flat_array_type array = Stokhos::reinterpret_as_unmanaged_scalar_flat_view(sendBuffer);
66+
#endif
5667
Ordinal array_count = count * Kokkos::dimension_scalar(sendBuffer);
5768
ssend(array, array_count, destRank, tag, comm);
5869
}
@@ -69,7 +80,11 @@ readySend (const Kokkos::View<D,P...>& sendBuffer,
6980
typedef Kokkos::View<D,P...> view_type;
7081
typedef typename Kokkos::FlatArrayType<view_type>::type flat_array_type;
7182

83+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
7284
flat_array_type array = sendBuffer;
85+
#else
86+
flat_array_type array = Stokhos::reinterpret_as_unmanaged_scalar_flat_view(sendBuffer);
87+
#endif
7388
Ordinal array_count = count * Kokkos::dimension_scalar(sendBuffer);
7489
readySend(array, array_count, destRank, tag, comm);
7590
}
@@ -85,7 +100,11 @@ isend (const Kokkos::View<D,P...>& sendBuffer,
85100
typedef Kokkos::View<D,P...> view_type;
86101
typedef typename Kokkos::FlatArrayType<view_type>::type flat_array_type;
87102

103+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
88104
flat_array_type array = sendBuffer;
105+
#else
106+
flat_array_type array = Stokhos::reinterpret_as_unmanaged_scalar_flat_view(sendBuffer);
107+
#endif
89108
return isend(array, destRank, tag, comm);
90109
}
91110

@@ -100,7 +119,11 @@ ireceive (const Kokkos::View<D,P...>& recvBuffer,
100119
typedef Kokkos::View<D,P...> view_type;
101120
typedef typename Kokkos::FlatArrayType<view_type>::type flat_array_type;
102121

122+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
103123
flat_array_type array = recvBuffer;
124+
#else
125+
flat_array_type array = Stokhos::reinterpret_as_unmanaged_scalar_flat_view(recvBuffer);
126+
#endif
104127
return ireceive(array, sourceRank, tag, comm);
105128
}
106129

packages/stokhos/src/sacado/kokkos/vector/tpetra/Stokhos_Tpetra_Utilities_MP_Vector.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "Tpetra_MultiVector.hpp"
1616
#include "Tpetra_CrsGraph.hpp"
1717
#include "Tpetra_CrsMatrix.hpp"
18+
#include "KokkosExp_View_MP_Vector_Contiguous.hpp"
1819

1920
namespace Stokhos {
2021

@@ -146,7 +147,15 @@ namespace Stokhos {
146147
using FlatVector = Tpetra::MultiVector<BaseScalar, LocalOrdinal, GlobalOrdinal, Node>;
147148

148149
// Create flattenend view using reshaping conversion copy constructor
150+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
149151
typename FlatVector::wrapped_dual_view_type flat_vals(vec.getWrappedDualView());
152+
#else
153+
using flat_dual_view_t = typename FlatVector::wrapped_dual_view_type::DVT;
154+
auto w_d_v = vec.getWrappedDualView();
155+
typename FlatVector::wrapped_dual_view_type flat_vals(//vec.getWrappedDualView()
156+
flat_dual_view_t(Stokhos::reinterpret_as_unmanaged_scalar_dual_view_of_same_rank(w_d_v.implGetDualView())),
157+
flat_dual_view_t(Stokhos::reinterpret_as_unmanaged_scalar_dual_view_of_same_rank(w_d_v.implGetOriginalDualView())));
158+
#endif
150159

151160
// Create flat vector
152161
return Teuchos::make_rcp<FlatVector>(flat_map, flat_vals);

packages/stokhos/src/sacado/kokkos/vector/tpetra/Tpetra_KokkosRefactor_Details_MultiVectorLocalDeepCopy_MP_Vector.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "Tpetra_KokkosRefactor_Details_MultiVectorLocalDeepCopy.hpp"
1414
#include "Stokhos_Sacado_Kokkos_MP_Vector.hpp"
1515

16+
#ifdef KOKKOS_ENABLE_IMPL_VIEW_LEGACY
1617
namespace Tpetra {
1718
namespace Details {
1819

@@ -56,5 +57,6 @@ namespace Details {
5657

5758
} // Details namespace
5859
} // Tpetra namespace
60+
#endif // KOKKOS_ENABLE_IMPL_VIEW_LEGACY
5961

6062
#endif // TPETRA_KOKKOS_REFACTOR_DETAILS_MULTI_VECTOR_LOCAL_DEEP_COPY_MP_VECTOR_HPP

0 commit comments

Comments
 (0)