diff --git a/unified-runtime/source/adapters/level_zero/kernel.cpp b/unified-runtime/source/adapters/level_zero/kernel.cpp index 06d1366a119b5..8182ad132ed2e 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/kernel.cpp @@ -58,6 +58,8 @@ ur_result_t urKernelGetSuggestedLocalWorkSize( inline ur_result_t KernelSetArgValueHelper( ur_kernel_handle_t Kernel, + /// [in][optional] the native handle of the kernel + ze_kernel_handle_t ZeKernel, /// [in] argument index in range [0, num args - 1] uint32_t ArgIndex, /// [in] size of argument type @@ -81,15 +83,20 @@ inline ur_result_t KernelSetArgValueHelper( } ze_result_t ZeResult = ZE_RESULT_SUCCESS; - if (Kernel->ZeKernelMap.empty()) { - auto ZeKernel = Kernel->ZeKernel; + if (ZeKernel) { ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, (ZeKernel, ArgIndex, ArgSize, PArgValue)); } else { - for (auto It : Kernel->ZeKernelMap) { - auto ZeKernel = It.second; + if (Kernel->ZeKernelMap.empty()) { + auto ZeKernel = Kernel->ZeKernel; ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, (ZeKernel, ArgIndex, ArgSize, PArgValue)); + } else { + for (auto It : Kernel->ZeKernelMap) { + auto ZeKernel = It.second; + ZeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, + (ZeKernel, ArgIndex, ArgSize, PArgValue)); + } } } @@ -143,120 +150,15 @@ inline ur_result_t KernelSetArgMemObjHelper( return UR_RESULT_SUCCESS; } -ur_result_t urEnqueueKernelLaunchWithArgsExp( - /// [in] handle of the queue object - ur_queue_handle_t Queue, - /// [in] handle of the kernel object - ur_kernel_handle_t Kernel, - /// [in] number of dimensions, from 1 to 3, to specify the global and - /// work-group work-items - uint32_t workDim, - /// [in][optional] pointer to an array of workDim unsigned values that - /// specify the offset used to calculate the global ID of a work-item - const size_t *GlobalWorkOffset, - /// [in] pointer to an array of workDim unsigned values that specify the - /// number of global work-items in workDim that will execute the kernel - /// function - const size_t *GlobalWorkSize, - /// [in][optional] pointer to an array of workDim unsigned values that - /// specify the number of local work-items forming a work-group that will - /// execute the kernel function. - /// If nullptr, the runtime implementation will choose the work-group size. - const size_t *LocalWorkSize, - /// [in] size of the event wait list - uint32_t NumArgs, - /// [in][optional][range(0, numArgs)] pointer to a list of kernel arg - /// properties. +// Helper for kernel launch APIs. +static ur_result_t EnqueueKernelLaunchCommon( + ur_queue_handle_t Queue, ur_kernel_handle_t Kernel, uint32_t WorkDim, + const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize, + const size_t *LocalWorkSize, uint32_t NumArgs, const ur_exp_kernel_arg_properties_t *Args, - /// [in] size of the launch prop list uint32_t NumPropsInLaunchPropList, - /// [in][range(0, numPropsInLaunchPropList)] pointer to a list of launch - /// properties const ur_kernel_launch_property_t *LaunchPropList, - uint32_t NumEventsInWaitList, - /// [in][optional][range(0, numEventsInWaitList)] pointer to a list of - /// events that must be complete before the kernel execution. If - /// nullptr, the numEventsInWaitList must be 0, indicating that no wait - /// event. - const ur_event_handle_t *EventWaitList, - /// [in,out][optional] return an event object that identifies this - /// particular kernel execution instance. - ur_event_handle_t *OutEvent) { - { - std::scoped_lock Guard(Kernel->Mutex); - for (uint32_t i = 0; i < NumArgs; i++) { - switch (Args[i].type) { - case UR_EXP_KERNEL_ARG_TYPE_LOCAL: - UR_CALL(KernelSetArgValueHelper(Kernel, Args[i].index, Args[i].size, - nullptr)); - break; - case UR_EXP_KERNEL_ARG_TYPE_VALUE: - UR_CALL(KernelSetArgValueHelper(Kernel, Args[i].index, Args[i].size, - Args[i].value.value)); - break; - case UR_EXP_KERNEL_ARG_TYPE_POINTER: - UR_CALL(KernelSetArgValueHelper(Kernel, Args[i].index, Args[i].size, - &Args[i].value.pointer)); - break; - case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ: { - ur_kernel_arg_mem_obj_properties_t Properties = { - UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES, nullptr, - Args[i].value.memObjTuple.flags}; - UR_CALL(KernelSetArgMemObjHelper(Kernel, Args[i].index, &Properties, - Args[i].value.memObjTuple.hMem)); - break; - } - case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: { - UR_CALL(KernelSetArgValueHelper(Kernel, Args[i].index, Args[i].size, - &Args[i].value.sampler->ZeSampler)); - break; - } - default: - return UR_RESULT_ERROR_INVALID_ENUMERATION; - } - } - } - // Normalize so each dimension has at least one work item - return level_zero::urEnqueueKernelLaunch( - Queue, Kernel, workDim, GlobalWorkOffset, GlobalWorkSize, LocalWorkSize, - NumPropsInLaunchPropList, LaunchPropList, NumEventsInWaitList, - EventWaitList, OutEvent); -} - -ur_result_t urEnqueueKernelLaunch( - /// [in] handle of the queue object - ur_queue_handle_t Queue, - /// [in] handle of the kernel object - ur_kernel_handle_t Kernel, - /// [in] number of dimensions, from 1 to 3, to specify the global and - /// work-group work-items - uint32_t WorkDim, - /// [in][optional] pointer to an array of workDim unsigned values that - /// specify the offset used to calculate the global ID of a work-item - const size_t *GlobalWorkOffset, - /// [in] pointer to an array of workDim unsigned values that specify the - /// number of global work-items in workDim that will execute the kernel - /// function - const size_t *GlobalWorkSize, - /// [in][optional] pointer to an array of workDim unsigned values that - /// specify the number of local work-items forming a work-group that - /// will execute the kernel function. If nullptr, the runtime - /// implementation will choose the work-group size. - const size_t *LocalWorkSize, - /// [in] size of the launch prop list - uint32_t NumPropsInLaunchPropList, - /// [in][range(0, numPropsInLaunchPropList)] pointer to a list of launch - /// properties - const ur_kernel_launch_property_t *LaunchPropList, - /// [in] size of the event wait list - uint32_t NumEventsInWaitList, - /// [in][optional][range(0, numEventsInWaitList)] pointer to a list of - /// events that must be complete before the kernel execution. If - /// nullptr, the numEventsInWaitList must be 0, indicating that no wait - /// event. - const ur_event_handle_t *EventWaitList, - /// [in,out][optional] return an event object that identifies this - /// particular kernel execution instance. + uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList, ur_event_handle_t *OutEvent) { using ZeKernelLaunchFuncT = ze_result_t (*)( ze_command_list_handle_t, ze_kernel_handle_t, const ze_group_count_t *, @@ -285,6 +187,39 @@ ur_result_t urEnqueueKernelLaunch( // Lock automatically releases when this goes out of scope. std::scoped_lock Lock( Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex); + for (uint32_t i = 0; i < NumArgs; i++) { + switch (Args[i].type) { + case UR_EXP_KERNEL_ARG_TYPE_LOCAL: + UR_CALL(KernelSetArgValueHelper(Kernel, ZeKernel, Args[i].index, + Args[i].size, nullptr)); + break; + case UR_EXP_KERNEL_ARG_TYPE_VALUE: + UR_CALL(KernelSetArgValueHelper(Kernel, ZeKernel, Args[i].index, + Args[i].size, Args[i].value.value)); + break; + case UR_EXP_KERNEL_ARG_TYPE_POINTER: + UR_CALL(KernelSetArgValueHelper(Kernel, ZeKernel, Args[i].index, + Args[i].size, &Args[i].value.pointer)); + break; + case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ: { + ur_kernel_arg_mem_obj_properties_t Properties = { + UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES, nullptr, + Args[i].value.memObjTuple.flags}; + UR_CALL(KernelSetArgMemObjHelper(Kernel, Args[i].index, &Properties, + Args[i].value.memObjTuple.hMem)); + break; + } + case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: { + UR_CALL(KernelSetArgValueHelper(Kernel, ZeKernel, Args[i].index, + Args[i].size, + &Args[i].value.sampler->ZeSampler)); + break; + } + default: + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + } + if (GlobalWorkOffset != NULL) { UR_CALL(setKernelGlobalOffset(Queue->Context, ZeKernel, WorkDim, GlobalWorkOffset)); @@ -391,6 +326,90 @@ ur_result_t urEnqueueKernelLaunch( return UR_RESULT_SUCCESS; } +ur_result_t urEnqueueKernelLaunchWithArgsExp( + /// [in] handle of the queue object + ur_queue_handle_t Queue, + /// [in] handle of the kernel object + ur_kernel_handle_t Kernel, + /// [in] number of dimensions, from 1 to 3, to specify the global and + /// work-group work-items + uint32_t workDim, + /// [in] pointer to an array of workDim unsigned values that specify the + /// offset used to calculate the global ID of a work-item + const size_t GlobalWorkOffset[3], + /// [in] pointer to an array of workDim unsigned values that specify the + /// number of global work-items in workDim that will execute the kernel + /// function + const size_t GlobalWorkSize[3], + /// [in][optional] pointer to an array of workDim unsigned values that + /// specify the number of local work-items forming a work-group that + /// will execute the kernel function. If nullptr, the runtime + /// implementation will choose the work-group size. + const size_t LocalWorkSize[3], + /// [in] size of the event wait list + uint32_t NumArgs, const ur_exp_kernel_arg_properties_t *Args, + /// [in] size of the launch prop list + uint32_t NumPropsInLaunchPropList, + /// [in][range(0, numPropsInLaunchPropList)] pointer to a list of launch + /// properties + const ur_kernel_launch_property_t *LaunchPropList, + uint32_t NumEventsInWaitList, + /// [in][optional][range(0, numEventsInWaitList)] pointer to a list of + /// events that must be complete before the kernel execution. If + /// nullptr, the numEventsInWaitList must be 0, indicating that no wait + /// event. + const ur_event_handle_t *EventWaitList, + /// [in,out][optional] return an event object that identifies this + /// particular kernel execution instance. + ur_event_handle_t *OutEvent) { + // Normalize so each dimension has at least one work item + return EnqueueKernelLaunchCommon( + Queue, Kernel, workDim, GlobalWorkOffset, GlobalWorkSize, LocalWorkSize, + NumArgs, Args, NumPropsInLaunchPropList, LaunchPropList, + NumEventsInWaitList, EventWaitList, OutEvent); +} + +ur_result_t urEnqueueKernelLaunch( + /// [in] handle of the queue object + ur_queue_handle_t Queue, + /// [in] handle of the kernel object + ur_kernel_handle_t Kernel, + /// [in] number of dimensions, from 1 to 3, to specify the global and + /// work-group work-items + uint32_t WorkDim, + /// [in][optional] pointer to an array of workDim unsigned values that + /// specify the offset used to calculate the global ID of a work-item + const size_t *GlobalWorkOffset, + /// [in] pointer to an array of workDim unsigned values that specify the + /// number of global work-items in workDim that will execute the kernel + /// function + const size_t *GlobalWorkSize, + /// [in][optional] pointer to an array of workDim unsigned values that + /// specify the number of local work-items forming a work-group that + /// will execute the kernel function. If nullptr, the runtime + /// implementation will choose the work-group size. + const size_t *LocalWorkSize, + /// [in] size of the launch prop list + uint32_t NumPropsInLaunchPropList, + /// [in][range(0, numPropsInLaunchPropList)] pointer to a list of launch + /// properties + const ur_kernel_launch_property_t *LaunchPropList, + /// [in] size of the event wait list + uint32_t NumEventsInWaitList, + /// [in][optional][range(0, numEventsInWaitList)] pointer to a list of + /// events that must be complete before the kernel execution. If + /// nullptr, the numEventsInWaitList must be 0, indicating that no wait + /// event. + const ur_event_handle_t *EventWaitList, + /// [in,out][optional] return an event object that identifies this + /// particular kernel execution instance. + ur_event_handle_t *OutEvent) { + return EnqueueKernelLaunchCommon( + Queue, Kernel, WorkDim, GlobalWorkOffset, GlobalWorkSize, LocalWorkSize, + 0 /* NumArgs */, nullptr /* Args */, NumPropsInLaunchPropList, + LaunchPropList, NumEventsInWaitList, EventWaitList, OutEvent); +} + ur_result_t urEnqueueDeviceGlobalVariableWrite( /// [in] handle of the queue to submit to. ur_queue_handle_t Queue, diff --git a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp index 04e202265d05c..cc0102eb1f80d 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp @@ -990,20 +990,24 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp( "ur_queue_immediate_in_order_t::enqueueKernelLaunchWithArgsExp"); { std::scoped_lock guard(hKernel->Mutex); + auto singleDeviceKernel = hKernel->getSingleDeviceKernel(hDevice.get()); + if (!singleDeviceKernel.has_value()) { + return UR_RESULT_ERROR_INVALID_DEVICE; + } for (uint32_t argIndex = 0; argIndex < numArgs; argIndex++) { switch (pArgs[argIndex].type) { case UR_EXP_KERNEL_ARG_TYPE_LOCAL: - UR_CALL(hKernel->setArgValue(pArgs[argIndex].index, - pArgs[argIndex].size, nullptr, nullptr)); + UR_CALL(singleDeviceKernel->get().setArgValue( + pArgs[argIndex].index, pArgs[argIndex].size, nullptr, nullptr)); break; case UR_EXP_KERNEL_ARG_TYPE_VALUE: - UR_CALL(hKernel->setArgValue(pArgs[argIndex].index, - pArgs[argIndex].size, nullptr, - pArgs[argIndex].value.value)); + UR_CALL(singleDeviceKernel->get().setArgValue( + pArgs[argIndex].index, pArgs[argIndex].size, nullptr, + pArgs[argIndex].value.value)); break; case UR_EXP_KERNEL_ARG_TYPE_POINTER: - UR_CALL(hKernel->setArgPointer(pArgs[argIndex].index, nullptr, - pArgs[argIndex].value.pointer)); + UR_CALL(singleDeviceKernel->get().setArgPointer( + pArgs[argIndex].index, nullptr, pArgs[argIndex].value.pointer)); break; case UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ: // TODO: import helper for converting ur flags to internal equivalent @@ -1013,9 +1017,9 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp( pArgs[argIndex].index})); break; case UR_EXP_KERNEL_ARG_TYPE_SAMPLER: { - UR_CALL( - hKernel->setArgValue(argIndex, sizeof(void *), nullptr, - &pArgs[argIndex].value.sampler->ZeSampler)); + UR_CALL(singleDeviceKernel->get().setArgValue( + argIndex, sizeof(void *), nullptr, + &pArgs[argIndex].value.sampler->ZeSampler)); break; } default: diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp index f48a41154e0f7..296d4fedb9a31 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp @@ -180,11 +180,8 @@ ur_kernel_handle_t_::getProperties(ur_device_handle_t hDevice) const { return deviceKernel.zeKernelProperties.get(); } -ur_result_t ur_kernel_handle_t_::setArgValue( - uint32_t argIndex, size_t argSize, - const ur_kernel_arg_value_properties_t * /*pProperties*/, - const void *pArgValue) { - +namespace { +inline const void *normalizePointerArg(size_t argSize, const void *pArgValue) { // OpenCL: "the arg_value pointer can be NULL or point to a NULL value // in which case a NULL value will be used as the value for the argument // declared as a pointer to global or constant memory in the kernel" @@ -194,31 +191,59 @@ ur_result_t ur_kernel_handle_t_::setArgValue( // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. if (argSize == sizeof(void *) && pArgValue && *(void **)(const_cast(pArgValue)) == nullptr) { - pArgValue = nullptr; + return nullptr; } - - if (argIndex > zeCommonProperties->numKernelArgs - 1) { - return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX; + return pArgValue; +} +ur_result_t setArgValueHelper(ze_kernel_handle_t zeKernel, uint32_t argIndex, + size_t argSize, const void *pArgValue) { + pArgValue = normalizePointerArg(argSize, pArgValue); + auto zeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, + (zeKernel, argIndex, argSize, pArgValue)); + if (zeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) { + return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; + } else if (zeResult != ZE_RESULT_SUCCESS) { + return ze2urResult(zeResult); } + return UR_RESULT_SUCCESS; +} +} // namespace + +ur_result_t ur_single_device_kernel_t::setArgValue( + uint32_t argIndex, size_t argSize, + const ur_kernel_arg_value_properties_t * /*pProperties*/, + const void *pArgValue) { + pArgValue = normalizePointerArg(argSize, pArgValue); + return setArgValueHelper(hKernel.get(), argIndex, argSize, pArgValue); +} +ur_result_t ur_kernel_handle_t_::setArgValue( + uint32_t argIndex, size_t argSize, + const ur_kernel_arg_value_properties_t * /*pProperties*/, + const void *pArgValue) { + + pArgValue = normalizePointerArg(argSize, pArgValue); for (auto &singleDeviceKernel : deviceKernels) { if (!singleDeviceKernel.has_value()) { continue; } - auto zeResult = ZE_CALL_NOCHECK(zeKernelSetArgumentValue, - (singleDeviceKernel.value().hKernel.get(), - argIndex, argSize, pArgValue)); - - if (zeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) { - return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE; - } else if (zeResult != ZE_RESULT_SUCCESS) { - return ze2urResult(zeResult); - } + auto Result = setArgValueHelper(singleDeviceKernel.value().hKernel.get(), + argIndex, argSize, pArgValue); + if (Result != UR_RESULT_SUCCESS) + return Result; } return UR_RESULT_SUCCESS; } +ur_result_t ur_single_device_kernel_t::setArgPointer( + uint32_t argIndex, + const ur_kernel_arg_pointer_properties_t * /*pProperties*/, + const void *pArgValue) { + // KernelSetArgValue is expecting a pointer to the argument + return setArgValue(argIndex, sizeof(const void *), nullptr, &pArgValue); +} + ur_result_t ur_kernel_handle_t_::setArgPointer( uint32_t argIndex, const ur_kernel_arg_pointer_properties_t * /*pProperties*/, @@ -232,6 +257,15 @@ ur_program_handle_t ur_kernel_handle_t_::getProgramHandle() const { return hProgram; } +std::optional> +ur_kernel_handle_t_::getSingleDeviceKernel(ur_device_handle_t hDevice) { + size_t index = deviceIndex(hDevice); + if (index >= deviceKernels.size() || !deviceKernels[index]) { + return std::nullopt; + } + return std::ref(deviceKernels[index].value()); +} + ur_result_t ur_kernel_handle_t_::setExecInfo(ur_kernel_exec_info_t propName, const void *pPropValue) { for (auto &kernel : deviceKernels) { diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp index 9c823760c42f2..a5689c4c9d6a9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp @@ -24,6 +24,17 @@ struct ur_single_device_kernel_t { ur_device_handle_t hDevice; v2::raii::ze_kernel_handle_t hKernel; mutable ZeCache> zeKernelProperties; + + // Implementation of urKernelSetArgValue. + ur_result_t setArgValue(uint32_t argIndex, size_t argSize, + const ur_kernel_arg_value_properties_t *pProperties, + const void *pArgValue); + + // Implementation of urKernelSetArgPointer. + ur_result_t + setArgPointer(uint32_t argIndex, + const ur_kernel_arg_pointer_properties_t *pProperties, + const void *pArgValue); }; struct ur_kernel_handle_t_ : ur_object { @@ -94,6 +105,9 @@ struct ur_kernel_handle_t_ : ur_object { ur::RefCount RefCount; + std::optional> + getSingleDeviceKernel(ur_device_handle_t hDevice); + private: // Keep the program of the kernel. const ur_program_handle_t hProgram;