Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (C) 2024 Intel Corporation
* Copyright (C) 2024-2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
Expand Down Expand Up @@ -85,4 +85,15 @@ kernel void testGlobalOffset(global int *globalOffsets) {
globalOffsets[1] += get_global_offset(1);
globalOffsets[2] += get_global_offset(2);
}
}

// Exercises two local-memory (SLM) buffers passed as kernel arguments.
//
// Each work-item stores `value + lid` into slm_1 and `value - lid` into slm_2
// at its own local index, then adds both slots to the global size. The `lid`
// terms cancel under unsigned arithmetic, so every element of `out` ends up as
// get_global_size(0) + 2 * value.
//
// out   - one uint per work-item, indexed by global id.
// slm_1 - local buffer indexed by local id (dimension 0).
// slm_2 - local buffer indexed by local id (dimension 0).
// value - constant folded into both local buffers.
kernel void test_slm_mutation(global uint *out, local uint *slm_1, local uint *slm_2, uint value) {
  const uint local_index = get_local_id(0);
  const uint global_index = get_global_id(0);
  const uint total_items = get_global_size(0);

  slm_1[local_index] = value + local_index;
  slm_2[local_index] = value - local_index;

  // Work-group barrier with a local-memory fence before reading the slots back.
  barrier(CLK_LOCAL_MEM_FENCE);

  const uint slm_sum = slm_1[local_index] + slm_2[local_index];
  out[global_index] = total_items + slm_sum;
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (C) 2024 Intel Corporation
* Copyright (C) 2024-2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
Expand Down Expand Up @@ -199,6 +199,104 @@ LZT_TEST_F(
lzt::destroy_function(addKernel);
}

// Value-parameterized variant of the zeMutableCommandListTests fixture.
// The parameter is a tuple of three uint32_t values; the SLM mutation test
// reads them as (initial group size X, mutated group size X, group count X).
class zeMutableCommandListSLMTests
: public zeMutableCommandListTests,
public ::testing::WithParamInterface<
std::tuple<uint32_t, uint32_t, uint32_t>> {};

// Launches the "test_slm_mutation" kernel, then mutates the recorded launch so
// that both SLM (local memory) argument sizes and the X group size change, and
// checks the kernel output after each execution.
//
// Test parameters (see GetParam()):
//   <0> initial group size X
//   <1> mutated group size X
//   <2> group count X (not mutated)
//
// The kernel writes `global_size + 2 * value` into every output element, so
// the expected result depends only on the global size in effect for the run.
//
// The output buffer is allocated once, sized for the larger of the two global
// sizes, and stays alive for both executions: kernel argument 0 is bound to
// this pointer when the launch is recorded and is not part of the mutation,
// so the buffer must not be freed or replaced between runs.
LZT_TEST_P(
    zeMutableCommandListSLMTests,
    GivenMutationOfSLMKernelArgumentsWhenCommandListIsClosedThenArgumentsWereReplaced) {
  if (!kernelArgumentsSupport || !groupSizeSupport) {
    GTEST_SKIP() << "Not all required extensions are supported";
  }
  const uint32_t group_size_x = std::get<0>(GetParam());
  const uint32_t mutated_group_size_x = std::get<1>(GetParam());
  const uint32_t group_count_x = std::get<2>(GetParam());

  const uint32_t global_size = group_size_x * group_count_x;
  const uint32_t mutated_global_size = mutated_group_size_x * group_count_x;

  // One buffer large enough for both executions; argument 0 keeps pointing at
  // it after the mutation.
  const uint32_t max_global_size =
      global_size > mutated_global_size ? global_size : mutated_global_size;
  const auto output_bytes = max_global_size * sizeof(uint32_t);

  uint32_t verify_value = 21u;
  uint32_t *output =
      reinterpret_cast<uint32_t *>(lzt::allocate_host_memory(output_bytes));
  std::memset(output, 0, output_bytes);

  ze_kernel_handle_t slm_kernel =
      lzt::create_function(module, "test_slm_mutation");

  lzt::set_group_size(slm_kernel, group_size_x, 1, 1);
  lzt::set_argument_value(slm_kernel, 0, sizeof(void *), &output);
  // SLM arguments are passed as a size with a null value pointer.
  lzt::set_argument_value(slm_kernel, 1, group_size_x * sizeof(uint32_t),
                          nullptr);
  lzt::set_argument_value(slm_kernel, 2, group_size_x * sizeof(uint32_t),
                          nullptr);
  lzt::set_argument_value(slm_kernel, 3, sizeof(uint32_t), &verify_value);

  // Reserve a command id for the launch so it can be mutated afterwards.
  uint64_t command_id = 0;
  commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS |
                        ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE;
  EXPECT_ZE_RESULT_SUCCESS(zeCommandListGetNextCommandIdExp(
      mutableCmdList, &commandIdDesc, &command_id));
  ze_group_count_t group_count{group_count_x, 1, 1};
  lzt::append_launch_function(mutableCmdList, slm_kernel, &group_count, nullptr,
                              0, nullptr);
  lzt::close_command_list(mutableCmdList);
  lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr);
  lzt::synchronize(queue, std::numeric_limits<uint64_t>::max());

  for (uint32_t i = 0; i < global_size; i++) {
    EXPECT_EQ(output[i], global_size + verify_value * 2);
  }

  // Reset the shared buffer so results of the first run cannot satisfy the
  // checks of the second one.
  std::memset(output, 0, output_bytes);

  // Descriptor chain handed to the update call:
  //   mutableCmdDesc -> group size -> SLM arg 1 -> SLM arg 2
  ze_mutable_kernel_argument_exp_desc_t mutate_kernel_slm_arg_2 = {
      ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
  mutate_kernel_slm_arg_2.commandId = command_id;
  mutate_kernel_slm_arg_2.argIndex = 2;
  mutate_kernel_slm_arg_2.argSize = mutated_group_size_x * sizeof(uint32_t);
  mutate_kernel_slm_arg_2.pArgValue = nullptr;

  ze_mutable_kernel_argument_exp_desc_t mutate_kernel_slm_arg_1 = {
      ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC};
  mutate_kernel_slm_arg_1.commandId = command_id;
  mutate_kernel_slm_arg_1.argIndex = 1;
  mutate_kernel_slm_arg_1.argSize = mutated_group_size_x * sizeof(uint32_t);
  mutate_kernel_slm_arg_1.pArgValue = nullptr;
  mutate_kernel_slm_arg_1.pNext = &mutate_kernel_slm_arg_2;

  ze_mutable_group_size_exp_desc_t mutate_group_size_desc = {
      ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC};
  mutate_group_size_desc.commandId = command_id;
  mutate_group_size_desc.groupSizeX = mutated_group_size_x;
  mutate_group_size_desc.groupSizeY = 1;
  mutate_group_size_desc.groupSizeZ = 1;
  mutate_group_size_desc.pNext = &mutate_kernel_slm_arg_1;
  mutableCmdDesc.pNext = &mutate_group_size_desc;

  EXPECT_ZE_RESULT_SUCCESS(
      zeCommandListUpdateMutableCommandsExp(mutableCmdList, &mutableCmdDesc));

  lzt::close_command_list(mutableCmdList);
  lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr);
  lzt::synchronize(queue, std::numeric_limits<uint64_t>::max());

  for (uint32_t i = 0; i < mutated_global_size; i++) {
    EXPECT_EQ(output[i], mutated_global_size + verify_value * 2);
  }

  lzt::free_memory(output);
  lzt::destroy_function(slm_kernel);
}

// Runs the SLM mutation test for every combination of:
//   initial group size X : 1, 16, 32, 64
//   mutated group size X : 1, 16, 32, 64
//   group count X        : 1, 2
INSTANTIATE_TEST_SUITE_P(MutableCommandListSLMTests,
zeMutableCommandListSLMTests,
::testing::Combine(::testing::Values(1, 16, 32, 64),
::testing::Values(1, 16, 32, 64),
::testing::Values(1, 2)));

LZT_TEST_F(
zeMutableCommandListTests,
GivenMutationOfGroupCountWhenCommandListIsClosedThenGlobalWorkSizeIsUpdated) {
Expand Down