@@ -38,11 +38,12 @@ namespace gpuprim = ::cub;
38
38
#include " rocm/rocm_config.h"
39
39
namespace gpuprim = ::hipcub;
40
40
41
- // Required for sorting Eigen::half and bfloat16.
42
41
namespace rocprim {
42
+
43
+ #if (TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000)
44
+ // Required for sorting Eigen::half and bfloat16.
43
45
namespace detail {
44
46
45
- #if (TF_ROCM_VERSION >= 50200)
46
47
template <>
47
48
struct float_bit_mask <Eigen::half> {
48
49
static constexpr uint16_t sign_bit = 0x8000 ;
@@ -58,14 +59,41 @@ struct float_bit_mask<tsl::bfloat16> {
58
59
static constexpr uint16_t mantissa = 0x007F ;
59
60
using bit_type = uint16_t ;
60
61
};
61
- #endif // TF_ROCM_VERSION >= 50200
62
+
63
+ }; // namespace detail
64
+ #else
65
+ namespace traits {
66
+
67
+ template <>
68
+ struct rocprim ::traits::define<Eigen::half> {
69
+ using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t , 0x8000 , 0x7C00 , 0x03FF >;
70
+ using is_arithmetic = rocprim::traits::is_arithmetic::values<true >;
71
+ using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
72
+ };
73
+
74
+ template <>
75
+ struct rocprim ::traits::define<tsl::bfloat16> {
76
+ using float_bit_mask = rocprim::traits::float_bit_mask::values<uint16_t , 0x8000 , 0x7F80 , 0x007F >;
77
+ using is_arithmetic = rocprim::traits::is_arithmetic::values<true >;
78
+ using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
79
+ };
80
+
81
+ }; // namespace traits
82
+ #endif // TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000
83
+
84
+ #if (TF_ROCM_VERSION < 70000)
85
+ namespace detail {
86
+
62
87
template <>
63
88
struct radix_key_codec_base <Eigen::half>
64
89
: radix_key_codec_floating<Eigen::half, uint16_t > {};
65
90
template <>
66
91
struct radix_key_codec_base <tsl::bfloat16>
67
92
: radix_key_codec_floating<tsl::bfloat16, uint16_t > {};
93
+
68
94
}; // namespace detail
95
+ #endif // TF_ROCM_VERSION < 70000
96
+
69
97
}; // namespace rocprim
70
98
71
99
#endif // TENSORFLOW_USE_ROCM
0 commit comments