diff --git a/tools/common_lib/src/dml_base_node.h b/tools/common_lib/src/dml_base_node.h index 6c226bc..60f54f2 100644 --- a/tools/common_lib/src/dml_base_node.h +++ b/tools/common_lib/src/dml_base_node.h @@ -150,6 +150,7 @@ inline dml::TensorPolicy to_dml_tensor_policy(DataLayout layout) switch (layout) { case DataLayout::eCHW: + case DataLayout::eNCDHW: case DataLayout::eNCHW: return dml::TensorPolicy::Default(); case DataLayout::eNHWC: return dml::TensorPolicy::InterleavedChannel(); case DataLayout::eW: return dml::TensorPolicy(compute_w_tensor_policy); diff --git a/tools/common_lib/src/layers_utils.h b/tools/common_lib/src/layers_utils.h index c3e909d..b890ce1 100644 --- a/tools/common_lib/src/layers_utils.h +++ b/tools/common_lib/src/layers_utils.h @@ -160,8 +160,9 @@ enum class DataLayout eNHWC_AlignH48, // example layout for unpacked tensor, ToDo: refactor cross runner to work with strides instead of hardcoded data layouts eCHW, // 3d dims for GEMMS eW, - - + eNCDHW, + eNHW, + eNCH, // .. // .. 
@@ -194,6 +195,9 @@ inline std::string data_layout_name(DataLayout l) case DataLayout::eIO_i8_o8_i2: return "IO_i8_o8_i2"; case DataLayout::eOYXI_o8: return "OYXI_o8"; case DataLayout::eOYXI_o16: return "OYXI_o16"; + case DataLayout::eNCDHW: return "NCDHW"; + case DataLayout::eNHW: return "NHW"; + case DataLayout::eNCH: return "NCH"; default: assert(false && "Unknown data layout name."); return ""; @@ -215,6 +217,11 @@ inline std::uint8_t data_layout_dimensions_count(DataLayout l) return 3; case DataLayout::eW: return 1; + case DataLayout::eNCDHW: + return 5; + case DataLayout::eNHW: + case DataLayout::eNCH: + return 3; default: return 0; } @@ -252,6 +256,7 @@ inline std::size_t data_layout_h_alignment(const DataLayout l) inline TensorShape data_layout_to_strides(TensorShape shape, DataLayout l) { const auto c = shape.c; + const auto d = shape.d; const auto h = static_cast(align(shape.h, data_layout_h_alignment(l))); const auto w = static_cast(align(shape.w, data_layout_w_alignment(l))); @@ -279,6 +284,28 @@ inline TensorShape data_layout_to_strides(TensorShape shape, DataLayout l) c); break; } + case DataLayout::eNCDHW: + { + ret = TensorShape( + c * h * d * w, + h * d * w, + h * w, + w, + 1); + break; + } + case DataLayout::eNHW: + ret = TensorShape( + h * w, + w, + 1); + break; + case DataLayout::eNCH: + ret = TensorShape( + c * h, + h, + 1); + break; default: assert("!unsupported right now"); } @@ -361,8 +388,8 @@ inline auto add_data_type_cli_option(CLI::App* opts, std::string_view opt_name, inline auto add_data_layout_cli_option(CLI::App* opts, std::string_view opt_name, DataLayout& layout) { - return opts->add_option(opt_name.data(), layout)->check(CLI::IsMember({DataLayout::eNCHW, DataLayout::eNHWC, DataLayout::eW, DataLayout::eCHW, DataLayout::eNCHW_AlignW320, DataLayout::eNHWC_AlignH48 })) + return opts->add_option(opt_name.data(), layout)->check(CLI::IsMember({DataLayout::eNCHW, DataLayout::eNHWC, DataLayout::eW, DataLayout::eCHW, DataLayout::eNCHW_AlignW320, DataLayout::eNHWC_AlignH48, DataLayout::eNCDHW })) ->transform(CLI::Transformer(std::map{ - {"nchw", DataLayout::eNCHW}, { "nhwc", 
DataLayout::eNHWC }, { "w", DataLayout::eW }, { "chw", DataLayout::eCHW }, { "nchw_alignw320", DataLayout::eNCHW_AlignW320 }, { "nhwc_alignh48", DataLayout::eNHWC_AlignH48 }, + {"nchw", DataLayout::eNCHW}, { "nhwc", DataLayout::eNHWC }, { "w", DataLayout::eW }, { "chw", DataLayout::eCHW }, { "nchw_alignw320", DataLayout::eNCHW_AlignW320 }, { "nhwc_alignh48", DataLayout::eNHWC_AlignH48 }, { "ncdhw", DataLayout::eNCDHW }, }, CLI::ignore_case, CLI::ignore_underscore)); } diff --git a/tools/common_lib/src/mha.h b/tools/common_lib/src/mha.h index 41004a3..a89683c 100644 --- a/tools/common_lib/src/mha.h +++ b/tools/common_lib/src/mha.h @@ -9,6 +9,7 @@ enum class MhaType { // qkv MhaType_QKV = 1, + MhaType_Q_KV = 2, }; namespace gpu_op { @@ -16,22 +17,24 @@ namespace gpu_op { public: Mha(MhaType mha_type, const DML_TENSOR_DATA_TYPE data_type, const dml::TensorPolicy& tensor_policy, - const TensorShape& shape_input, const TensorShape& shape_out, + const TensorShape& shape_input_qkv, const TensorShape& shape_input_q, const TensorShape& shape_input_kv, const TensorShape& shape_out, IDMLDevice* dml_device, ID3D12Device* d3d12_device, bool disable_mc = false) :DirectMlBaseNode(dml_device, d3d12_device) { - if (mha_type == MhaType::MhaType_QKV) // todo: move some codes out as general code for other mha types + type_ = mha_type; + if (mha_type == MhaType::MhaType_QKV || mha_type == MhaType::MhaType_Q_KV) // todo: move some codes out as general code for other mha types { - - const dml::TensorDimensions input_dims{ shape_input.n, shape_input.c, shape_input.d, shape_input.h, shape_input.w}; + const dml::TensorDimensions input_dims_qkv{ shape_input_qkv.n, shape_input_qkv.c, shape_input_qkv.h, shape_input_qkv.d, shape_input_qkv.w }; + const dml::TensorDimensions input_dims_q{ shape_input_q.n, shape_input_q.c, shape_input_q.h }; + const dml::TensorDimensions input_dims_kv{ shape_input_kv.n, shape_input_kv.c, shape_input_kv.h, shape_input_kv.d, shape_input_kv.w }; const 
dml::TensorDimensions output_dims{ shape_out.n, shape_out.h, shape_out.w }; dml::TensorProperties stacked_qkv_tensor_properites{}; { tensor_stacked_qkv_desc_.DataType = data_type; tensor_stacked_qkv_desc_.Flags = DML_TENSOR_FLAG_NONE; - tensor_stacked_qkv_desc_.DimensionCount = static_cast(input_dims.size()); - tensor_stacked_qkv_desc_.Sizes = input_dims.data(); + tensor_stacked_qkv_desc_.DimensionCount = static_cast(input_dims_qkv.size()); + tensor_stacked_qkv_desc_.Sizes = input_dims_qkv.data(); //stacked_qkv_tensor_properites = input_tensor_policy.Get(tensor_stacked_qkv_desc_.DataType, tensor_stacked_qkv_desc_.Flags, input_dims); tensor_stacked_qkv_desc_.Strides = nullptr; // stacked_qkv_tensor_properites.strides.has_value() ? stacked_qkv_tensor_properites.strides->data() : nullptr; @@ -44,6 +47,41 @@ namespace gpu_op } + dml::TensorProperties stacked_q_tensor_properites{}; + { + tensor_stacked_q_desc_.DataType = data_type; + tensor_stacked_q_desc_.Flags = DML_TENSOR_FLAG_NONE; + tensor_stacked_q_desc_.DimensionCount = static_cast(input_dims_q.size()); + tensor_stacked_q_desc_.Sizes = input_dims_q.data(); + + //stacked_qkv_tensor_properites = input_tensor_policy.Get(tensor_stacked_qkv_desc_.DataType, tensor_stacked_qkv_desc_.Flags, input_dims); + tensor_stacked_q_desc_.Strides = nullptr; // stacked_qkv_tensor_properites.strides.has_value() ? 
stacked_qkv_tensor_properites.strides->data() : nullptr; + tensor_stacked_q_desc_.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( + tensor_stacked_q_desc_.DataType, + tensor_stacked_q_desc_.DimensionCount, + tensor_stacked_q_desc_.Sizes, + tensor_stacked_q_desc_.Strides); //stacked_qkv_tensor_properites.totalTensorSizeInBytes; + tensor_stacked_q_desc_.GuaranteedBaseOffsetAlignment = 0; // stacked_qkv_tensor_properites.guaranteedBaseOffsetAlignment; + + } + + dml::TensorProperties stacked_kv_tensor_properites{}; + { + tensor_stacked_kv_desc_.DataType = data_type; + tensor_stacked_kv_desc_.Flags = DML_TENSOR_FLAG_NONE; + tensor_stacked_kv_desc_.DimensionCount = static_cast(input_dims_kv.size()); + tensor_stacked_kv_desc_.Sizes = input_dims_kv.data(); + + //stacked_qkv_tensor_properites = input_tensor_policy.Get(tensor_stacked_qkv_desc_.DataType, tensor_stacked_qkv_desc_.Flags, input_dims); + tensor_stacked_kv_desc_.Strides = nullptr; // stacked_qkv_tensor_properites.strides.has_value() ? 
stacked_qkv_tensor_properites.strides->data() : nullptr; + tensor_stacked_kv_desc_.TotalTensorSizeInBytes = DMLCalcBufferTensorSize( + tensor_stacked_kv_desc_.DataType, + tensor_stacked_kv_desc_.DimensionCount, + tensor_stacked_kv_desc_.Sizes, + tensor_stacked_kv_desc_.Strides); //stacked_qkv_tensor_properites.totalTensorSizeInBytes; + tensor_stacked_kv_desc_.GuaranteedBaseOffsetAlignment = 0; // stacked_qkv_tensor_properites.guaranteedBaseOffsetAlignment; + + } dml::TensorProperties output_tensor_properites; { tensor_out_desc_.DataType = data_type; @@ -67,9 +105,29 @@ namespace gpu_op DML_TENSOR_DESC stacked_qkv_tensor_desc = {}; stacked_qkv_tensor_desc.Desc = &tensor_stacked_qkv_desc_; stacked_qkv_tensor_desc.Type = DML_TENSOR_TYPE_BUFFER; + + DML_TENSOR_DESC stacked_q_tensor_desc = {}; + stacked_q_tensor_desc.Desc = &tensor_stacked_q_desc_; + stacked_q_tensor_desc.Type = DML_TENSOR_TYPE_BUFFER; + + DML_TENSOR_DESC stacked_kv_tensor_desc = {}; + stacked_kv_tensor_desc.Desc = &tensor_stacked_kv_desc_; + stacked_kv_tensor_desc.Type = DML_TENSOR_TYPE_BUFFER; DML_MULTIHEAD_ATTENTION_OPERATOR_DESC desc = {}; - desc.StackedQueryKeyValueTensor = &stacked_qkv_tensor_desc; + if (mha_type == MhaType::MhaType_QKV) + { + desc.StackedQueryKeyValueTensor = &stacked_qkv_tensor_desc; + } + else if (mha_type == MhaType::MhaType_Q_KV) + { + desc.QueryTensor = &stacked_q_tensor_desc; + desc.StackedKeyValueTensor = &stacked_kv_tensor_desc; + } + else + { + assert(false && "Unsupported MHA type!"); + } desc.OutputTensor = &output_desc; desc.Scale = 0.001f; desc.MaskFilterValue = -10000000; @@ -115,36 +173,73 @@ namespace gpu_op return tensor_out_desc_; } + MhaType get_mha_type() + { + return type_; + } void record_execute(IDMLCommandRecorder* dml_cmd_recorder, ID3D12GraphicsCommandList* cmd_list, - ID3D12Resource* resource_out, ID3D12Resource* resource_input) + ID3D12Resource* resource_out, ID3D12Resource* resource_input_qkv, ID3D12Resource* resource_input_q, ID3D12Resource* 
resource_input_kv) { - - assert(resource_input); - assert(resource_out); - DML_BUFFER_BINDING input_buffer_binding{ resource_input, 0, resource_input->GetDesc().Width }; - std::vector input_bindings; - input_bindings.reserve(11); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &input_buffer_binding }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); - - std::vector output_bindings; - output_bindings.reserve(3); - DML_BUFFER_BINDING output_buffer_binding{ resource_out, 0, resource_out->GetDesc().Width }; - DML_BINDING_DESC output_binding_desc{ DML_BINDING_TYPE_BUFFER, &output_buffer_binding }; - output_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &output_buffer_binding }); - output_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); - output_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); - - record_execute_impl(dml_cmd_recorder, cmd_list, input_bindings, output_bindings); + if (type_ == MhaType::MhaType_QKV) + { + assert(resource_input_qkv); + assert(resource_out); + DML_BUFFER_BINDING input_buffer_binding_qkv{ resource_input_qkv, 0, resource_input_qkv->GetDesc().Width }; + std::vector input_bindings; + input_bindings.reserve(11); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ 
DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &input_buffer_binding_qkv }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + + std::vector output_bindings; + output_bindings.reserve(3); + DML_BUFFER_BINDING output_buffer_binding{ resource_out, 0, resource_out->GetDesc().Width }; + DML_BINDING_DESC output_binding_desc{ DML_BINDING_TYPE_BUFFER, &output_buffer_binding }; + output_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &output_buffer_binding }); + output_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + output_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + + record_execute_impl(dml_cmd_recorder, cmd_list, input_bindings, output_bindings); + } + if (type_ == MhaType::MhaType_Q_KV) + { + assert(resource_input_q); + assert(resource_input_kv); + assert(resource_out); + DML_BUFFER_BINDING input_buffer_binding_q{ resource_input_q, 0, resource_input_q->GetDesc().Width }; + DML_BUFFER_BINDING input_buffer_binding_kv{ resource_input_kv, 0, resource_input_kv->GetDesc().Width }; + std::vector input_bindings; + input_bindings.reserve(11); + input_bindings.push_back({ DML_BINDING_TYPE_BUFFER , &input_buffer_binding_q }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_BUFFER , &input_buffer_binding_kv }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , 
nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + input_bindings.push_back({ DML_BINDING_TYPE_NONE , nullptr }); + + std::vector output_bindings; + output_bindings.reserve(3); + DML_BUFFER_BINDING output_buffer_binding{ resource_out, 0, resource_out->GetDesc().Width }; + DML_BINDING_DESC output_binding_desc{ DML_BINDING_TYPE_BUFFER, &output_buffer_binding }; + output_bindings.push_back({ DML_BINDING_TYPE_BUFFER, &output_buffer_binding }); + output_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + output_bindings.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + + record_execute_impl(dml_cmd_recorder, cmd_list, input_bindings, output_bindings); + } } virtual void record_initialize(IDMLCommandRecorder* dml_cmd_recorder, ID3D12GraphicsCommandList* cmd_list) @@ -174,6 +269,8 @@ namespace gpu_op private: ComPtr dml_operator_; DML_BUFFER_TENSOR_DESC tensor_stacked_qkv_desc_; + DML_BUFFER_TENSOR_DESC tensor_stacked_q_desc_; + DML_BUFFER_TENSOR_DESC tensor_stacked_kv_desc_; DML_BUFFER_TENSOR_DESC tensor_out_desc_; MhaType type_{}; @@ -189,20 +286,25 @@ class MhaBaseDispatcher: public NodeDispatcher DataType dt; DataLayout layout; - TensorShape shape_input; - + TensorShape shape_input_qkv; + TensorShape shape_input_q; + TensorShape shape_input_kv; inline static void add_cli_options(CLI::App* opts, create_params_t& params) { add_data_type_cli_option(opts, "--data_type", params.dt)->required(); add_data_layout_cli_option(opts, "--layout", params.layout)->required(); - opts->add_option("--shape_input", params.shape_input)->required(); - opts->add_option("--mha_type", params.type, "Name of the type of MHA to run.") - ->check(CLI::IsMember({ MhaType::MhaType_QKV }))-> + ->check(CLI::IsMember({ MhaType::MhaType_QKV, MhaType::MhaType_Q_KV }))-> transform(CLI::Transformer(std::map{ { "qkv", MhaType::MhaType_QKV }, + { "q_kv", MhaType::MhaType_Q_KV }, }, 
CLI::ignore_case))->required(); + + opts->add_option("--shape_input_qkv", params.shape_input_qkv); + opts->add_option("--shape_input_q", params.shape_input_q); + opts->add_option("--shape_input_kv", params.shape_input_kv); + } }; @@ -214,12 +316,22 @@ class MhaBaseDispatcher: public NodeDispatcher , d3d12_device_(d3d12_device) { - input_data_.resize(get_tensor_elements_count(params_.shape_input, params_.layout) * get_data_type_bytes_width(params_.dt)); if (params_.type == MhaType::MhaType_QKV) { - assert(params_.shape_input.get_dims_count() == 5); - assert(!input_data_.empty()); - } else + input_data_stacked_qkv.resize(get_tensor_elements_count(params_.shape_input_qkv, params_.layout) * get_data_type_bytes_width(params_.dt)); + assert(params_.shape_input_qkv.get_dims_count() == 5); + assert(!input_data_stacked_qkv.empty()); + } + else if (params_.type == MhaType::MhaType_Q_KV) + { + input_data_stacked_q.resize(get_tensor_elements_count(params_.shape_input_q, DataLayout::eNCH) * get_data_type_bytes_width(params_.dt)); + input_data_stacked_kv.resize(get_tensor_elements_count(params_.shape_input_kv, params_.layout) * get_data_type_bytes_width(params_.dt)); + assert(params_.shape_input_q.get_dims_count() == 3); + assert(params_.shape_input_kv.get_dims_count() == 5); + assert(!input_data_stacked_q.empty()); + assert(!input_data_stacked_kv.empty()); + } + else { assert(false && "Not supported MHA type!"); } @@ -230,57 +342,114 @@ class MhaBaseDispatcher: public NodeDispatcher if (params_.dt == DataType::eFp32) { - randomize_linear_container_float(random_generator, uniform_distribution, input_data_); + randomize_linear_container_float(random_generator, uniform_distribution, input_data_stacked_qkv); + if (params_.type == MhaType::MhaType_Q_KV) + { + randomize_linear_container_float(random_generator, uniform_distribution, input_data_stacked_q); + randomize_linear_container_float(random_generator, uniform_distribution, input_data_stacked_kv); + } } else if (params_.dt == 
DataType::eFp16) { - randomize_linear_container_half(random_generator, uniform_distribution, input_data_); + randomize_linear_container_half(random_generator, uniform_distribution, input_data_stacked_qkv); + //fill_with_constant_linear_container_half(input_data_stacked_qkv, DirectX::PackedVector::XMConvertFloatToHalf(0.1f)); + if (params_.type == MhaType::MhaType_Q_KV) + { + randomize_linear_container_half(random_generator, uniform_distribution, input_data_stacked_q); + //fill_with_constant_linear_container_half(input_data_stacked_q, DirectX::PackedVector::XMConvertFloatToHalf(0.1f)); + //fill_with_constant_linear_container_half(input_data_stacked_kv, DirectX::PackedVector::XMConvertFloatToHalf(0.1f)); + randomize_linear_container_half(random_generator, uniform_distribution, input_data_stacked_kv); + } } else { assert(false && "Unsupported data type in convolution dispatcher!"); } - const auto tensor_input_bytes_width = input_data_.size(); - + + const auto tensor_input_bytes_width_qkv = input_data_stacked_qkv.size(); + const auto tensor_input_bytes_width_q = input_data_stacked_q.size(); + const auto tensor_input_bytes_width_kv = input_data_stacked_kv.size(); + const auto out_shape = get_shape_output(); - const auto tensor_out_bytes_width = get_tensor_elements_count(out_shape, params_.layout) * get_data_type_bytes_width(params_.dt); - - upload_buffer_ = create_buffer(d3d12_device_, tensor_input_bytes_width, - D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ); - input_buffer_ = create_buffer(d3d12_device, tensor_input_bytes_width, - D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); - - output_buffer_ = create_buffer(d3d12_device, tensor_out_bytes_width, - D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); - - // copy data into buffer - std::byte* upload_mapped_ptr = nullptr; - upload_buffer_->Map(0, nullptr, reinterpret_cast(&upload_mapped_ptr)); - 
std::size_t memcopy_offset = 0; - std::memcpy(upload_mapped_ptr, input_data_.data(), tensor_input_bytes_width); - memcopy_offset += tensor_input_bytes_width; - - // unmap memory - upload_buffer_->Unmap(0, nullptr); - - memcopy_offset = 0; - cmd_list->CopyBufferRegion(input_buffer_.Get(), 0, upload_buffer_.Get(), 0, tensor_input_bytes_width); - memcopy_offset += tensor_input_bytes_width; - - std::vector barriers; - barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(input_buffer_.Get(), - D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_UNORDERED_ACCESS)); - - cmd_list->ResourceBarrier(static_cast(barriers.size()), barriers.data()); + const auto tensor_out_bytes_width = get_tensor_elements_count(out_shape, DataLayout::eNHW) * get_data_type_bytes_width(params_.dt); + + if (params_.type == MhaType::MhaType_QKV) + { + upload_buffer_ = create_buffer(d3d12_device_, tensor_input_bytes_width_qkv, + D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ); + input_buffer_qkv = create_buffer(d3d12_device, tensor_input_bytes_width_qkv, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + + output_buffer_ = create_buffer(d3d12_device, tensor_out_bytes_width, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + + // copy data into buffer + std::byte* upload_mapped_ptr = nullptr; + upload_buffer_->Map(0, nullptr, reinterpret_cast(&upload_mapped_ptr)); + std::size_t memcopy_offset = 0; + std::memcpy(upload_mapped_ptr, input_data_stacked_qkv.data(), tensor_input_bytes_width_qkv); + memcopy_offset += tensor_input_bytes_width_qkv; + + // unmap memory + upload_buffer_->Unmap(0, nullptr); + + memcopy_offset = 0; + cmd_list->CopyBufferRegion(input_buffer_qkv.Get(), 0, upload_buffer_.Get(), 0, tensor_input_bytes_width_qkv); + memcopy_offset += tensor_input_bytes_width_qkv; + + std::vector barriers; + 
barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(input_buffer_qkv.Get(), + D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_UNORDERED_ACCESS)); + + cmd_list->ResourceBarrier(static_cast(barriers.size()), barriers.data()); + } + if (params_.type == MhaType::MhaType_Q_KV) + { + upload_buffer_ = create_buffer(d3d12_device_, tensor_input_bytes_width_q + tensor_input_bytes_width_kv, + D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ); + input_buffer_q = create_buffer(d3d12_device, tensor_input_bytes_width_q, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + input_buffer_kv = create_buffer(d3d12_device, tensor_input_bytes_width_kv, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + + output_buffer_ = create_buffer(d3d12_device, tensor_out_bytes_width, + D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + + // copy data into buffer + std::byte* upload_mapped_ptr = nullptr; + upload_buffer_->Map(0, nullptr, reinterpret_cast(&upload_mapped_ptr)); + std::size_t memcopy_offset = 0; + std::memcpy(upload_mapped_ptr, input_data_stacked_q.data(), tensor_input_bytes_width_q); + memcopy_offset += tensor_input_bytes_width_q; + std::memcpy(upload_mapped_ptr + memcopy_offset, input_data_stacked_kv.data(), tensor_input_bytes_width_kv); + memcopy_offset += tensor_input_bytes_width_kv; + + // unmap memory + upload_buffer_->Unmap(0, nullptr); + + memcopy_offset = 0; + cmd_list->CopyBufferRegion(input_buffer_q.Get(), 0, upload_buffer_.Get(), 0, tensor_input_bytes_width_q); + memcopy_offset += tensor_input_bytes_width_q; + cmd_list->CopyBufferRegion(input_buffer_kv.Get(), 0, upload_buffer_.Get(), memcopy_offset, tensor_input_bytes_width_kv); + memcopy_offset += tensor_input_bytes_width_kv; + + std::vector barriers; + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(input_buffer_q.Get(), + 
D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_UNORDERED_ACCESS)); + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(input_buffer_kv.Get(), + D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_UNORDERED_ACCESS)); + + cmd_list->ResourceBarrier(static_cast(barriers.size()), barriers.data()); + } } virtual ConformanceResult validate_conformance(ID3D12CommandQueue* command_queue, ID3D12CommandAllocator* command_allocator, ID3D12GraphicsCommandList* command_list, bool print_mismatches, std::size_t reference_dispatch_iterations) { const auto out_shape = get_shape_output(); - const auto tensor_out_bytes_width = get_tensor_elements_count(out_shape, params_.layout) * get_data_type_bytes_width(params_.dt); + const auto tensor_out_bytes_width = get_tensor_elements_count(out_shape, DataLayout::eNHW) * get_data_type_bytes_width(params_.dt); // readback data and validate auto readback_buffer = create_buffer(d3d12_device_, tensor_out_bytes_width, D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST); @@ -304,7 +473,7 @@ class MhaBaseDispatcher: public NodeDispatcher command_list->ResourceBarrier(1, &readback_output_barrirer); gpu_op::Mha mha_ref(params_.type, to_dml_data_type(params_.dt), to_dml_tensor_policy(params_.layout), - params_.shape_input, get_shape_output(), + params_.shape_input_qkv, params_.shape_input_q, params_.shape_input_kv, get_shape_output(), dml_device_, d3d12_device_, true /*disable metacommand*/); // bind descriptor heap @@ -317,7 +486,7 @@ class MhaBaseDispatcher: public NodeDispatcher close_execute_reset_wait(d3d12_device_, command_queue, command_allocator, command_list); command_list->SetDescriptorHeaps(1, d3d12_descriptor_heaps); - mha_ref.record_execute(dml_cmd_recorder_, command_list, output_buffer_.Get(), input_buffer_.Get()); + mha_ref.record_execute(dml_cmd_recorder_, command_list, output_buffer_.Get(), input_buffer_qkv.Get(), input_buffer_q.Get(), input_buffer_kv.Get()); close_execute_reset_wait(d3d12_device_, command_queue, 
command_allocator, command_list); readback_output_barrirer = CD3DX12_RESOURCE_BARRIER::Transition(output_buffer_.Get(), @@ -336,11 +505,11 @@ class MhaBaseDispatcher: public NodeDispatcher if (params_.dt == DataType::eFp32) { - return run_conformance_check(data_out, ref_untyped_result, 0.05f, print_mismatches); + return run_conformance_check(data_out, ref_untyped_result, 0.005f, print_mismatches); } else if (params_.dt == DataType::eFp16) { - return run_conformance_check(data_out, ref_untyped_result, 0.05f, print_mismatches); + return run_conformance_check(data_out, ref_untyped_result, 0.005f, print_mismatches); } assert(false && "Unsupported output data type!"); ConformanceResult ret{}; @@ -359,7 +528,16 @@ class MhaBaseDispatcher: public NodeDispatcher std::uint32_t get_batch() const { - return params_.shape_input.n; + if (params_.type == MhaType::MhaType_QKV) + { + return params_.shape_input_qkv.n; + } + if (params_.type == MhaType::MhaType_Q_KV) + { + return params_.shape_input_q.n; + } + assert(false && "Not supported"); + return 0; } @@ -367,7 +545,11 @@ class MhaBaseDispatcher: public NodeDispatcher { if (params_.type == MhaType::MhaType_QKV) { - return params_.shape_input.c; + return params_.shape_input_qkv.c; + } + if (params_.type == MhaType::MhaType_Q_KV) + { + return params_.shape_input_q.c; } assert(false && "Not supported"); return 0; @@ -376,11 +558,14 @@ class MhaBaseDispatcher: public NodeDispatcher std::uint32_t get_width() const { - if (params_.type == MhaType::MhaType_QKV) + if (params_.type == MhaType::MhaType_QKV) + { + return params_.shape_input_qkv.h * params_.shape_input_qkv.w; + } + if (params_.type == MhaType::MhaType_Q_KV) { - return params_.shape_input.d * params_.shape_input.w; + return params_.shape_input_q.h; } - assert(false && "Not supported"); return 0; } @@ -391,9 +576,13 @@ class MhaBaseDispatcher: public NodeDispatcher IDMLDevice* dml_device_; IDMLCommandRecorder* dml_cmd_recorder_; - std::vector input_data_; + std::vector 
input_data_stacked_qkv; + std::vector input_data_stacked_q; + std::vector input_data_stacked_kv; - ComPtr input_buffer_; + ComPtr input_buffer_qkv; + ComPtr input_buffer_q; + ComPtr input_buffer_kv; ComPtr output_buffer_; ComPtr upload_buffer_; @@ -405,7 +594,7 @@ class MhaDmlDispatcher : public MhaBaseDispatcher public: MhaDmlDispatcher(create_params_t&& params, ID3D12Device* d3d12_device, IDMLDevice* dml_device, IDMLCommandRecorder* dml_cmd_recorder, ID3D12GraphicsCommandList* cmd_list) : MhaBaseDispatcher(std::move(params), d3d12_device, dml_device, dml_cmd_recorder, cmd_list) - , mha_(params_.type, to_dml_data_type(params_.dt), to_dml_tensor_policy(params_.layout), params_.shape_input, get_shape_output(), + , mha_(params_.type, to_dml_data_type(params_.dt), to_dml_tensor_policy(params_.layout), params_.shape_input_qkv, params_.shape_input_q, params_.shape_input_kv, get_shape_output(), dml_device, d3d12_device, false) { @@ -425,8 +614,7 @@ class MhaDmlDispatcher : public MhaBaseDispatcher void execute(ID3D12GraphicsCommandList* cmd_list) { - mha_.record_execute(dml_cmd_recorder_, cmd_list, - output_buffer_.Get(), input_buffer_.Get()); + mha_.record_execute(dml_cmd_recorder_, cmd_list, output_buffer_.Get(), input_buffer_qkv.Get(), input_buffer_q.Get(), input_buffer_kv.Get()); } private: gpu_op::Mha mha_; diff --git a/tools/common_lib/src/tensor_shape.h b/tools/common_lib/src/tensor_shape.h index a0043ac..8082513 100644 --- a/tools/common_lib/src/tensor_shape.h +++ b/tools/common_lib/src/tensor_shape.h @@ -9,11 +9,18 @@ struct TensorShape TensorShape() = default; + TensorShape(std::uint32_t n, std::uint32_t h, std::uint32_t w) + : n(n), h(h), w(w) + { + } TensorShape(std::uint32_t n, std::uint32_t c, std::uint32_t h, std::uint32_t w) : n(n), c(c), h(h), w(w) { } - + TensorShape(std::uint32_t n, std::uint32_t c, std::uint32_t d, std::uint32_t h, std::uint32_t w) + : n(n), c(c), d(d), h(h), w(w) + { + } TensorShape(std::span in_v) { assert(!(in_v.size() < 3 || 
in_v.size() > 5) && "Not supported shape!");