1414* limitations under the License.
1515*******************************************************************************/
1616
17- #include " xpu/ocl/engine_impl.hpp"
18- #include " xpu/ocl/stream_impl.hpp"
19-
2017#include " gpu/intel/ocl/usm_utils.hpp"
18+ #include " gpu/intel/ocl/engine.hpp"
2119
2220namespace dnnl {
2321namespace impl {
@@ -42,6 +40,27 @@ cl_command_queue get_ocl_queue(impl::stream_t *stream) {
4240 return utils::downcast<xpu::ocl::stream_impl_t *>(stream->impl ())->queue ();
4341}
4442
43+ const compute::device_info_t *get_device_info (impl::engine_t *engine) {
44+ return utils::downcast<const ocl::engine_t *>(engine)->device_info ();
45+ }
46+
47+ template <typename F>
48+ void *usm_malloc_common (impl::engine_t *engine, size_t size, F ext_func) {
49+ auto device_info = get_device_info (engine);
50+
51+ if (size == 0 || size > device_info->memory_size ()) return nullptr ;
52+ bool large_buffer = size > device_info->max_allocation_size ();
53+ cl_bitfield large_buffer_flag[]
54+ = {CL_MEM_FLAGS_INTEL, CL_MEM_ALLOW_UNRESTRICTED_SIZE_INTEL, 0 };
55+
56+ cl_int err;
57+ void *p = ext_func (engine, get_ocl_context (engine), get_ocl_device (engine),
58+ large_buffer ? large_buffer_flag : nullptr , size, 0 , &err);
59+ assert (utils::one_of (err, CL_SUCCESS, CL_OUT_OF_RESOURCES,
60+ CL_OUT_OF_HOST_MEMORY, CL_INVALID_BUFFER_SIZE));
61+ return p;
62+ }
63+
4564} // namespace
4665
4766bool is_usm_supported (impl::engine_t *engine) {
@@ -55,13 +74,18 @@ bool is_usm_supported(impl::engine_t *engine) {
5574void *malloc_host (impl::engine_t *engine, size_t size) {
5675 using clHostMemAllocINTEL_func_t = void *(*)(cl_context, const cl_ulong *,
5776 size_t , cl_uint, cl_int *);
58-
59- if (size == 0 ) return nullptr ;
60-
6177 static xpu::ocl::ext_func_t <clHostMemAllocINTEL_func_t> ext_func (
6278 " clHostMemAllocINTEL" );
79+ auto device_info = get_device_info (engine);
80+
81+ if (size == 0 || size > device_info->memory_size ()) return nullptr ;
82+ bool large_buffer = size > device_info->max_allocation_size ();
83+ cl_bitfield large_buffer_flag[]
84+ = {CL_MEM_FLAGS_INTEL, CL_MEM_ALLOW_UNRESTRICTED_SIZE_INTEL, 0 };
85+
6386 cl_int err;
64- void *p = ext_func (engine, get_ocl_context (engine), nullptr , size, 0 , &err);
87+ void *p = ext_func (engine, get_ocl_context (engine),
88+ large_buffer ? large_buffer_flag : nullptr , size, 0 , &err);
6589 assert (utils::one_of (err, CL_SUCCESS, CL_OUT_OF_RESOURCES,
6690 CL_OUT_OF_HOST_MEMORY, CL_INVALID_BUFFER_SIZE));
6791 return p;
@@ -70,38 +94,21 @@ void *malloc_host(impl::engine_t *engine, size_t size) {
7094void *malloc_device (impl::engine_t *engine, size_t size) {
7195 using clDeviceMemAllocINTEL_func_t = void *(*)(cl_context, cl_device_id,
7296 cl_ulong *, size_t , cl_uint, cl_int *);
73-
74- if (size == 0 ) return nullptr ;
75-
7697 static xpu::ocl::ext_func_t <clDeviceMemAllocINTEL_func_t> ext_func (
7798 " clDeviceMemAllocINTEL" );
78- cl_int err;
79- void *p = ext_func (engine, get_ocl_context (engine), get_ocl_device (engine),
80- nullptr , size, 0 , &err);
81- assert (utils::one_of (err, CL_SUCCESS, CL_OUT_OF_RESOURCES,
82- CL_OUT_OF_HOST_MEMORY, CL_INVALID_BUFFER_SIZE));
83- return p;
99+ return usm_malloc_common (engine, size, ext_func);
84100}
85101
86102void *malloc_shared (impl::engine_t *engine, size_t size) {
87103 using clSharedMemAllocINTEL_func_t = void *(*)(cl_context, cl_device_id,
88104 cl_ulong *, size_t , cl_uint, cl_int *);
89-
90- if (size == 0 ) return nullptr ;
91-
92105 static xpu::ocl::ext_func_t <clSharedMemAllocINTEL_func_t> ext_func (
93106 " clSharedMemAllocINTEL" );
94- cl_int err;
95- void *p = ext_func (engine, get_ocl_context (engine), get_ocl_device (engine),
96- nullptr , size, 0 , &err);
97- assert (utils::one_of (err, CL_SUCCESS, CL_OUT_OF_RESOURCES,
98- CL_OUT_OF_HOST_MEMORY, CL_INVALID_BUFFER_SIZE));
99- return p;
107+ return usm_malloc_common (engine, size, ext_func);
100108}
101109
102110void free (impl::engine_t *engine, void *ptr) {
103111 using clMemFreeINTEL_func_t = cl_int (*)(cl_context, void *);
104-
105112 if (!ptr) return ;
106113 static xpu::ocl::ext_func_t <clMemFreeINTEL_func_t> ext_func (
107114 " clMemFreeINTEL" );
0 commit comments