diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index bcc873bfb03180..712663cfd88bbe 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -82,8 +82,8 @@ namespace gpuprim = ::hipcub; // Required for sorting Eigen::half and bfloat16. namespace rocprim { +#if (TF_ROCM_VERSION >= 50200 && TF_ROCM_VERSION < 70000) namespace detail { -#if (TF_ROCM_VERSION >= 50200) template <> struct float_bit_mask { static constexpr uint16_t sign_bit = 0x8000; @@ -99,7 +99,27 @@ struct float_bit_mask { static constexpr uint16_t mantissa = 0x007F; using bit_type = uint16_t; }; +}; // namespace detail + +#else +namespace traits { +template<> +struct rocprim::traits::define { + using float_bit_mask = rocprim::traits::float_bit_mask::values; + using is_arithmetic = rocprim::traits::is_arithmetic::values; + using number_format = rocprim::traits::number_format::values; +}; + +template<> +struct rocprim::traits::define { + using float_bit_mask = rocprim::traits::float_bit_mask::values; + using is_arithmetic = rocprim::traits::is_arithmetic::values; + using number_format = rocprim::traits::number_format::values; +}; +}; // namespace traits #endif +#if (TF_ROCM_VERSION < 70000) +namespace detail { template <> struct radix_key_codec_base : radix_key_codec_floating {}; @@ -107,6 +127,7 @@ template <> struct radix_key_codec_base : radix_key_codec_floating {}; }; // namespace detail +#endif }; // namespace rocprim #endif // TENSORFLOW_USE_ROCM