|
1 | 1 | /* |
2 | 2 | * |
3 | | - * Copyright (C) 2024 Intel Corporation |
| 3 | + * Copyright (C) 2024-2025 Intel Corporation |
4 | 4 | * |
5 | 5 | * SPDX-License-Identifier: MIT |
6 | 6 | * |
@@ -199,6 +199,104 @@ LZT_TEST_F( |
199 | 199 | lzt::destroy_function(addKernel); |
200 | 200 | } |
201 | 201 |
|
| 202 | +class zeMutableCommandListSLMTests |
| 203 | + : public zeMutableCommandListTests, |
| 204 | + public ::testing::WithParamInterface< |
| 205 | + std::tuple<uint32_t, uint32_t, uint32_t>> {}; |
| 206 | + |
| 207 | +LZT_TEST_P( |
| 208 | + zeMutableCommandListSLMTests, |
| 209 | + GivenMutationOfSLMKernelArgumentsWhenCommandListIsClosedThenArgumentsWereReplaced) { |
| 210 | + if (!kernelArgumentsSupport || !groupSizeSupport) { |
| 211 | + GTEST_SKIP() << "Not all required extensions are supported"; |
| 212 | + } |
| 213 | + uint32_t group_size_x = std::get<0>(GetParam()); |
| 214 | + uint32_t mutated_group_size_x = std::get<1>(GetParam()); |
| 215 | + uint32_t group_count_x = std::get<2>(GetParam()); |
| 216 | + |
| 217 | + uint32_t global_size = group_size_x * group_count_x; |
| 218 | + uint32_t mutated_global_size = mutated_group_size_x * group_count_x; |
| 219 | + |
| 220 | + uint32_t verify_value = 21u; |
| 221 | + uint32_t *output = reinterpret_cast<uint32_t *>( |
| 222 | + lzt::allocate_host_memory(global_size * sizeof(uint32_t))); |
| 223 | + std::memset(output, 0, global_size * sizeof(uint32_t)); |
| 224 | + |
| 225 | + ze_kernel_handle_t slm_kernel = |
| 226 | + lzt::create_function(module, "test_slm_mutation"); |
| 227 | + |
| 228 | + lzt::set_group_size(slm_kernel, group_size_x, 1, 1); |
| 229 | + lzt::set_argument_value(slm_kernel, 0, sizeof(void *), &output); |
| 230 | + lzt::set_argument_value(slm_kernel, 1, group_size_x * sizeof(uint32_t), |
| 231 | + nullptr); |
| 232 | + lzt::set_argument_value(slm_kernel, 2, group_size_x * sizeof(uint32_t), |
| 233 | + nullptr); |
| 234 | + lzt::set_argument_value(slm_kernel, 3, sizeof(uint32_t), &verify_value); |
| 235 | + |
| 236 | + uint64_t command_id = 0; |
| 237 | + commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | |
| 238 | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE; |
| 239 | + EXPECT_ZE_RESULT_SUCCESS(zeCommandListGetNextCommandIdExp( |
| 240 | + mutableCmdList, &commandIdDesc, &command_id)); |
| 241 | + ze_group_count_t group_count{group_count_x, 1, 1}; |
| 242 | + lzt::append_launch_function(mutableCmdList, slm_kernel, &group_count, nullptr, |
| 243 | + 0, nullptr); |
| 244 | + lzt::close_command_list(mutableCmdList); |
| 245 | + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); |
| 246 | + lzt::synchronize(queue, std::numeric_limits<uint64_t>::max()); |
| 247 | + |
| 248 | + for (uint32_t i = 0; i < global_size; i++) { |
| 249 | + EXPECT_EQ(output[i], global_size + verify_value * 2); |
| 250 | + } |
| 251 | + |
| 252 | + lzt::free_memory(output); |
| 253 | + output = reinterpret_cast<uint32_t *>( |
| 254 | + lzt::allocate_host_memory(mutated_global_size * sizeof(uint32_t))); |
| 255 | + std::memset(output, 0, mutated_global_size * sizeof(uint32_t)); |
| 256 | + |
| 257 | + ze_mutable_kernel_argument_exp_desc_t mutate_kernel_slm_arg_2 = { |
| 258 | + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; |
| 259 | + mutate_kernel_slm_arg_2.commandId = command_id; |
| 260 | + mutate_kernel_slm_arg_2.argIndex = 2; |
| 261 | + mutate_kernel_slm_arg_2.argSize = mutated_group_size_x * sizeof(uint32_t); |
| 262 | + mutate_kernel_slm_arg_2.pArgValue = nullptr; |
| 263 | + ze_mutable_kernel_argument_exp_desc_t mutate_kernel_slm_arg_1 = { |
| 264 | + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; |
| 265 | + mutate_kernel_slm_arg_1.commandId = command_id; |
| 266 | + mutate_kernel_slm_arg_1.argIndex = 1; |
| 267 | + mutate_kernel_slm_arg_1.argSize = mutated_group_size_x * sizeof(uint32_t); |
| 268 | + mutate_kernel_slm_arg_1.pArgValue = nullptr; |
| 269 | + mutate_kernel_slm_arg_1.pNext = &mutate_kernel_slm_arg_2; |
| 270 | + ze_mutable_group_size_exp_desc_t mutate_group_size_desc = { |
| 271 | + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; |
| 272 | + mutate_group_size_desc.commandId = command_id; |
| 273 | + mutate_group_size_desc.groupSizeX = mutated_group_size_x; |
| 274 | + mutate_group_size_desc.groupSizeY = 1; |
| 275 | + mutate_group_size_desc.groupSizeZ = 1; |
| 276 | + mutate_group_size_desc.pNext = &mutate_kernel_slm_arg_1; |
| 277 | + mutableCmdDesc.pNext = &mutate_group_size_desc; |
| 278 | + |
| 279 | + EXPECT_ZE_RESULT_SUCCESS( |
| 280 | + zeCommandListUpdateMutableCommandsExp(mutableCmdList, &mutableCmdDesc)); |
| 281 | + |
| 282 | + lzt::close_command_list(mutableCmdList); |
| 283 | + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); |
| 284 | + lzt::synchronize(queue, std::numeric_limits<uint64_t>::max()); |
| 285 | + |
| 286 | + for (uint32_t i = 0; i < mutated_global_size; i++) { |
| 287 | + EXPECT_EQ(output[i], mutated_global_size + verify_value * 2); |
| 288 | + } |
| 289 | + |
| 290 | + lzt::free_memory(output); |
| 291 | + lzt::destroy_function(slm_kernel); |
| 292 | +} |
| 293 | + |
| 294 | +INSTANTIATE_TEST_SUITE_P(MutableCommandListSLMTests, |
| 295 | + zeMutableCommandListSLMTests, |
| 296 | + ::testing::Combine(::testing::Values(1, 16, 32, 64), |
| 297 | + ::testing::Values(1, 16, 32, 64), |
| 298 | + ::testing::Values(1, 2))); |
| 299 | + |
202 | 300 | LZT_TEST_F( |
203 | 301 | zeMutableCommandListTests, |
204 | 302 | GivenMutationOfGroupCountWhenCommandListIsClosedThenGlobalWorkSizeIsUpdated) { |
|
0 commit comments