Skip to content

Revert "[UR][SYCL] Introduce UR api to set kernel args + launch in one call." #19661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 50 additions & 157 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2310,14 +2310,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
}
}

// Gets UR argument struct for a given kernel and device based on the argument
// type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
// the graphs extension (LaunchWithArgs for graphs is planned future work).
static void GetUrArgsBasedOnType(
// Sets arguments for a given kernel and device based on the argument type.
// Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
// extension.
static void SetArgBasedOnType(
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
device_image_impl *DeviceImageImpl,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
switch (Arg.MType) {
case kernel_param_kind_t::kind_dynamic_work_group_memory:
break;
Expand All @@ -2337,61 +2337,52 @@ static void GetUrArgsBasedOnType(
getMemAllocationFunc
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
: nullptr;
ur_exp_kernel_arg_value_t Value = {};
Value.memObjTuple = {MemArg, AccessModeToUr(Req->MAccessMode)};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
static_cast<uint32_t>(NextTrueIndex), sizeof(MemArg),
Value});
ur_kernel_arg_mem_obj_properties_t MemObjData{};
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
&MemObjData, MemArg);
break;
}
case kernel_param_kind_t::kind_std_layout: {
ur_exp_kernel_arg_type_t Type;
if (Arg.MPtr) {
Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
Adapter.call<UrApiKind::urKernelSetArgValue>(
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
} else {
Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
Arg.MSize, nullptr);
}
ur_exp_kernel_arg_value_t Value = {};
Value.value = {Arg.MPtr};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
Type, static_cast<uint32_t>(NextTrueIndex),
static_cast<size_t>(Arg.MSize), Value});

break;
}
case kernel_param_kind_t::kind_sampler: {
sampler *SamplerPtr = (sampler *)Arg.MPtr;
ur_exp_kernel_arg_value_t Value = {};
Value.sampler = (ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
static_cast<uint32_t>(NextTrueIndex),
sizeof(ur_sampler_handle_t), Value});
ur_sampler_handle_t Sampler =
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
nullptr, Sampler);
break;
}
case kernel_param_kind_t::kind_pointer: {
ur_exp_kernel_arg_value_t Value = {};
// We need to dereference to get the actual USM allocation - that's the
// We need to dereference this to get the actual USM allocation - that's the
// pointer UR is expecting.
Value.pointer = *static_cast<void *const *>(Arg.MPtr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_POINTER,
static_cast<uint32_t>(NextTrueIndex), sizeof(Arg.MPtr),
Value});
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
break;
}
case kernel_param_kind_t::kind_specialization_constants_buffer: {
assert(DeviceImageImpl != nullptr);
ur_mem_handle_t SpecConstsBuffer =
DeviceImageImpl->get_spec_const_buffer_ref();
ur_exp_kernel_arg_value_t Value = {};
Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
static_cast<uint32_t>(NextTrueIndex),
sizeof(SpecConstsBuffer), Value});

ur_kernel_arg_mem_obj_properties_t MemObjProps{};
MemObjProps.pNext = nullptr;
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
break;
}
case kernel_param_kind_t::kind_invalid:
Expand Down Expand Up @@ -2424,32 +2415,22 @@ static ur_result_t SetKernelParamsAndLaunch(
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref() : Empty);
}

std::vector<ur_exp_kernel_arg_properties_t> UrArgs;
UrArgs.reserve(Args.size());

if (KernelFuncPtr && !KernelHasSpecialCaptures) {
auto setFunc = [&UrArgs,
auto setFunc = [&Adapter, Kernel,
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
size_t NextTrueIndex) {
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
switch (ParamDesc.kind) {
case kernel_param_kind_t::kind_std_layout: {
int Size = ParamDesc.info;
ur_exp_kernel_arg_value_t Value = {};
Value.value = ArgPtr;
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_VALUE,
static_cast<uint32_t>(NextTrueIndex),
static_cast<size_t>(Size), Value});
Adapter.call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
Size, nullptr, ArgPtr);
break;
}
case kernel_param_kind_t::kind_pointer: {
ur_exp_kernel_arg_value_t Value = {};
Value.pointer = *static_cast<const void *const *>(ArgPtr);
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr,
UR_EXP_KERNEL_ARG_TYPE_POINTER,
static_cast<uint32_t>(NextTrueIndex),
sizeof(Value.pointer), Value});
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
break;
}
default:
Expand All @@ -2459,10 +2440,10 @@ static ur_result_t SetKernelParamsAndLaunch(
applyFuncOnFilteredArgs(EliminatedArgMask, KernelNumArgs,
KernelParamDescGetter, setFunc);
} else {
auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
&UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
GetUrArgsBasedOnType(DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImpl(), Arg, NextTrueIndex, UrArgs);
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
Queue.getContextImpl(), Arg, NextTrueIndex);
};
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
}
Expand All @@ -2475,12 +2456,8 @@ static ur_result_t SetKernelParamsAndLaunch(
// CUDA-style local memory setting. Note that we may have -1 as a position,
// this indicates the buffer is actually unused and was elided.
if (ImplicitLocalArg.has_value() && ImplicitLocalArg.value() != -1) {
UrArgs.push_back({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
nullptr,
UR_EXP_KERNEL_ARG_TYPE_LOCAL,
static_cast<uint32_t>(ImplicitLocalArg.value()),
WorkGroupMemorySize,
{nullptr}});
Adapter.call<UrApiKind::urKernelSetArgLocal>(
Kernel, ImplicitLocalArg.value(), WorkGroupMemorySize, nullptr);
}

adjustNDRangePerKernel(NDRDesc, Kernel, Queue.getDeviceImpl());
Expand Down Expand Up @@ -2538,104 +2515,20 @@ static ur_result_t SetKernelParamsAndLaunch(
{{WorkGroupMemorySize}}});
}
ur_event_handle_t UREvent = nullptr;
ur_result_t Error =
Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr,
&NDRDesc.GlobalSize[0], LocalSize, UrArgs.size(), UrArgs.data(),
property_list.size(),
property_list.empty() ? nullptr : property_list.data(),
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0],
OutEventImpl ? &UREvent : nullptr);
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urEnqueueKernelLaunch>(
Queue.getHandleRef(), Kernel, NDRDesc.Dims,
HasOffset ? &NDRDesc.GlobalOffset[0] : nullptr, &NDRDesc.GlobalSize[0],
LocalSize, property_list.size(),
property_list.empty() ? nullptr : property_list.data(), RawEvents.size(),
RawEvents.empty() ? nullptr : &RawEvents[0],
OutEventImpl ? &UREvent : nullptr);
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
OutEventImpl->setHandle(UREvent);
}

return Error;
}

// Sets a single argument for a given kernel and device based on the argument
// type. This is a legacy path which the graphs extension still uses.
//
// Dispatches on Arg.MType and issues the matching urKernelSetArg* call through
// Adapter, using NextTrueIndex as the argument index passed to UR.
//
// Adapter              - adapter through which every UR call below is made.
// Kernel               - UR kernel handle whose argument is being set.
// DeviceImageImpl      - only read for kind_specialization_constants_buffer,
//                        where it is asserted to be non-null.
// getMemAllocationFunc - maps a Requirement to its memory allocation; may be
//                        empty (see the accessor case).
// ContextImpl          - context passed to getOrCreateSampler for samplers.
// Arg                  - argument descriptor (MType, MPtr and MSize are read).
// NextTrueIndex        - argument index forwarded to the UR calls.
//
// Throws sycl::exception (errc::runtime) for kind_invalid.
static void SetArgBasedOnType(
adapter_impl &Adapter, ur_kernel_handle_t Kernel,
device_image_impl *DeviceImageImpl,
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
switch (Arg.MType) {
// No UR call is made in this function for the following three kinds.
case kernel_param_kind_t::kind_dynamic_work_group_memory:
break;
case kernel_param_kind_t::kind_work_group_memory:
break;
case kernel_param_kind_t::kind_stream:
break;
case kernel_param_kind_t::kind_dynamic_accessor:
case kernel_param_kind_t::kind_accessor: {
// For accessor kinds, Arg.MPtr is treated as a pointer to the Requirement.
Requirement *Req = (Requirement *)(Arg.MPtr);

// getMemAllocationFunc is nullptr when there are no requirements. However,
// we may pass default constructed accessors to a command, which don't add
// requirements. In such case, getMemAllocationFunc is nullptr, but it's a
// valid case, so we need to properly handle it.
ur_mem_handle_t MemArg =
getMemAllocationFunc
? reinterpret_cast<ur_mem_handle_t>(getMemAllocationFunc(Req))
: nullptr;
// The accessor's access mode is translated to UR memory flags so that it is
// passed along with the memory object.
ur_kernel_arg_mem_obj_properties_t MemObjData{};
MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjData.memoryAccess = AccessModeToUr(Req->MAccessMode);
Adapter.call<UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
&MemObjData, MemArg);
break;
}
case kernel_param_kind_t::kind_std_layout: {
// A non-null MPtr is passed by value (MSize bytes read from MPtr); a null
// MPtr is set through urKernelSetArgLocal with MSize as the size.
if (Arg.MPtr) {
Adapter.call<UrApiKind::urKernelSetArgValue>(
Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
} else {
Adapter.call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
Arg.MSize, nullptr);
}

break;
}
case kernel_param_kind_t::kind_sampler: {
// Arg.MPtr is treated as a pointer to a sycl sampler; the UR sampler handle
// is obtained from its impl for the given context.
sampler *SamplerPtr = (sampler *)Arg.MPtr;
ur_sampler_handle_t Sampler =
(ur_sampler_handle_t)detail::getSyclObjImpl(*SamplerPtr)
->getOrCreateSampler(ContextImpl);
Adapter.call<UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
nullptr, Sampler);
break;
}
case kernel_param_kind_t::kind_pointer: {
// We need to dereference this to get the actual USM allocation - that's the
// pointer UR is expecting.
const void *Ptr = *static_cast<const void *const *>(Arg.MPtr);
Adapter.call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
break;
}
case kernel_param_kind_t::kind_specialization_constants_buffer: {
// The buffer comes from the device image, so one must have been supplied.
assert(DeviceImageImpl != nullptr);
ur_mem_handle_t SpecConstsBuffer =
DeviceImageImpl->get_spec_const_buffer_ref();

// The specialization constants buffer is passed as a read-only memory
// object.
ur_kernel_arg_mem_obj_properties_t MemObjProps{};
MemObjProps.pNext = nullptr;
MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
Adapter.call<UrApiKind::urKernelSetArgMemObj>(
Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
break;
}
case kernel_param_kind_t::kind_invalid:
// An invalid kind is reported as a runtime error; the break after the throw
// is never reached.
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Invalid kernel param kind " +
codeToString(UR_RESULT_ERROR_INVALID_VALUE));
break;
}
}

static std::tuple<ur_kernel_handle_t, device_image_impl *,
const KernelArgMask *>
getCGKernelInfo(const CGExecKernel &CommandGroup, context_impl &ContextImpl,
Expand Down
4 changes: 2 additions & 2 deletions sycl/test-e2e/Adapters/level_zero/batch_barrier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int main(int argc, char *argv[]) {
queue q;

submit_kernel(q); // starts a batch
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
// CHECK: ---> urEnqueueKernelLaunch
// CHECK-NOT: zeCommandQueueExecuteCommandLists

// Initialize Level Zero driver is required if this test is linked
Expand All @@ -41,7 +41,7 @@ int main(int argc, char *argv[]) {
// CHECK-NOT: zeCommandQueueExecuteCommandLists

submit_kernel(q);
// CHECK: ---> urEnqueueKernelLaunchWithArgsExp
// CHECK: ---> urEnqueueKernelLaunch
// CHECK-NOT: zeCommandQueueExecuteCommandLists

// interop should close the batch
Expand Down
Loading