Skip to content

Commit d4d77b6

Browse files
i-chaochen, mmakevic-amd, and jayfurmanek
authored
Update rocPrim usage for ROCm7 (#2979) (#2987)
Co-authored-by: mmakevic-amd <[email protected]> Co-authored-by: Jason Furmanek <[email protected]>
1 parent 936d47b commit d4d77b6

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

tensorflow/core/kernels/gpu_prim.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ namespace gpuprim = ::hipcub;
8585

8686
// Required for sorting Eigen::half and bfloat16.
8787
namespace rocprim {
88+
#if (TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000)
8889
namespace detail {
89-
#if (TF_ROCM_VERSION >= 50200)
9090
template <>
9191
struct float_bit_mask<Eigen::half> {
9292
static constexpr uint16_t sign_bit = 0x8000;
@@ -102,14 +102,35 @@ struct float_bit_mask<Eigen::bfloat16> {
102102
static constexpr uint16_t mantissa = 0x007F;
103103
using bit_type = uint16_t;
104104
};
105+
}; // namespace detail
106+
107+
#else
108+
namespace traits {
109+
template<>
110+
struct rocprim::traits::define<Eigen::half> {
111+
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t, 0x8000, 0x7C00, 0x03FF>;
112+
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
113+
using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
114+
};
115+
116+
template<>
117+
struct rocprim::traits::define<tsl::bfloat16> {
118+
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>;
119+
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
120+
using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
121+
};
122+
}; // namespace traits
105123
#endif
124+
#if (TF_ROCM_VERSION < 70000)
125+
namespace detail {
106126
template <>
107127
struct radix_key_codec_base<Eigen::half>
108128
: radix_key_codec_floating<Eigen::half, uint16_t> {};
109129
template <>
110130
struct radix_key_codec_base<tensorflow::bfloat16>
111131
: radix_key_codec_floating<tensorflow::bfloat16, uint16_t> {};
112132
}; // namespace detail
133+
#endif
113134
}; // namespace rocprim
114135

115136
#endif // TENSORFLOW_USE_ROCM

third_party/xla/xla/service/gpu/gpu_prim.h

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ namespace gpuprim = ::cub;
3838
#include "rocm/rocm_config.h"
3939
namespace gpuprim = ::hipcub;
4040

41-
// Required for sorting Eigen::half and bfloat16.
4241
namespace rocprim {
42+
43+
#if (TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000)
44+
// Required for sorting Eigen::half and bfloat16.
4345
namespace detail {
4446

45-
#if (TF_ROCM_VERSION >= 50200)
4647
template <>
4748
struct float_bit_mask<Eigen::half> {
4849
static constexpr uint16_t sign_bit = 0x8000;
@@ -58,14 +59,41 @@ struct float_bit_mask<tsl::bfloat16> {
5859
static constexpr uint16_t mantissa = 0x007F;
5960
using bit_type = uint16_t;
6061
};
61-
#endif // TF_ROCM_VERSION >= 50200
62+
63+
}; // namespace detail
64+
#else
65+
namespace traits {
66+
67+
template<>
68+
struct rocprim::traits::define<Eigen::half> {
69+
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t, 0x8000, 0x7C00, 0x03FF>;
70+
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
71+
using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
72+
};
73+
74+
template<>
75+
struct rocprim::traits::define<tsl::bfloat16> {
76+
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>;
77+
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
78+
using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
79+
};
80+
81+
}; // namespace traits
82+
#endif // TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000
83+
84+
#if (TF_ROCM_VERSION < 70000)
85+
namespace detail {
86+
6287
template <>
6388
struct radix_key_codec_base<Eigen::half>
6489
: radix_key_codec_floating<Eigen::half, uint16_t> {};
6590
template <>
6691
struct radix_key_codec_base<tsl::bfloat16>
6792
: radix_key_codec_floating<tsl::bfloat16, uint16_t> {};
93+
6894
}; // namespace detail
95+
#endif // TF_ROCM_VERSION < 70000
96+
6997
}; // namespace rocprim
7098

7199
#endif // TENSORFLOW_USE_ROCM

0 commit comments

Comments
 (0)