diff --git a/custom_ops/xpu_ops/src/ops/get_infer_param.cc b/custom_ops/xpu_ops/src/ops/get_infer_param.cc index 2c1e43b7c64..677e346d45a 100644 --- a/custom_ops/xpu_ops/src/ops/get_infer_param.cc +++ b/custom_ops/xpu_ops/src/ops/get_infer_param.cc @@ -25,7 +25,7 @@ namespace api = baidu::xpu::api; void lod_to_slot_mapping(api::Context* xpu_ctx, paddle::Place place, - const std::vector& block_table, + const paddle::Tensor& block_table_xpu, const std::vector& kv_seq_lod, const std::vector& start_tokens, const std::vector& real_batch, @@ -35,10 +35,16 @@ void lod_to_slot_mapping(api::Context* xpu_ctx, int32_t batch_size, int32_t max_num_blocks_per_seq, int32_t num_speculative_tokens) { - if (token_num <= 0) { + int32_t actual_token_num = kv_seq_lod[batch_size]; + if (token_num <= 0 || actual_token_num <= 0) { return; } - std::vector slot_mapping_vec(token_num, -1); + + int ret; + + std::vector block_table_idx_vec(actual_token_num); + std::vector seq_offset_vec(actual_token_num); + int32_t idx = 0; // For each Batch for (auto batch_ = 0; batch_ < batch_size; batch_++) { @@ -47,20 +53,56 @@ void lod_to_slot_mapping(api::Context* xpu_ctx, int32_t dst_batch_id = real_batch[batch_]; // for each token for (auto seq_ = seq_start; seq_ < seq_start + seq_len; seq_++) { - int32_t table_id = seq_ / block_size; - int32_t block_id = - block_table[dst_batch_id * max_num_blocks_per_seq + table_id]; - int32_t seq_offset = seq_ % block_size; - int32_t dst_token_offset = block_id * block_size + seq_offset; - slot_mapping_vec[idx] = dst_token_offset; + block_table_idx_vec[idx] = + seq_ / block_size + dst_batch_id * max_num_blocks_per_seq; + seq_offset_vec[idx] = seq_ % block_size; idx++; } } - int ret = api::do_host2device(xpu_ctx, - slot_mapping_vec.data(), - slot_mapping, - token_num * sizeof(int32_t)); + auto block_table_idx = + paddle::empty({actual_token_num}, paddle::DataType::INT32, place); + auto seq_offset = + paddle::empty({actual_token_num}, paddle::DataType::INT32, place); + ret = api::do_host2device(xpu_ctx, + block_table_idx_vec.data(), + block_table_idx.data(), + actual_token_num * sizeof(int32_t)); PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed."); + ret = api::do_host2device(xpu_ctx, + seq_offset_vec.data(), + seq_offset.data(), + actual_token_num * sizeof(int32_t)); + PD_CHECK(ret == api::SUCCESS, "api::do_host2device failed."); + + // int32_t block_id = + // block_table[dst_batch_id * max_num_blocks_per_seq + table_id]; + auto block_id = + paddle::empty({actual_token_num}, paddle::DataType::INT32, place); + auto block_size_tensor = + paddle::full({1}, block_size, paddle::DataType::INT32, place); + ret = api::index_select(xpu_ctx, + block_table_xpu.data(), + block_table_idx.data(), + block_id.data(), + {block_table_xpu.numel()}, + actual_token_num, + 0); + PD_CHECK(ret == api::SUCCESS, "api::index_select failed."); + // int32_t dst_token_offset = block_id * block_size + seq_offset; + ret = api::broadcast_mul(xpu_ctx, + block_id.data(), + block_size_tensor.data(), + block_id.data(), + {actual_token_num}, + {1}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_mul failed."); + ret = api::broadcast_add(xpu_ctx, + block_id.data(), + seq_offset.data(), + slot_mapping, + {actual_token_num}, + {actual_token_num}); + PD_CHECK(ret == api::SUCCESS, "api::broadcast_add failed."); } std::vector GetInferParam( @@ -68,6 +110,28 @@ std::vector GetInferParam( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& block_tables, + paddle::Tensor& encoder_batch_map, + paddle::Tensor& decoder_batch_map, + paddle::Tensor& encoder_batch_idx, + paddle::Tensor& decoder_batch_idx, + paddle::Tensor& encoder_seq_lod, + paddle::Tensor& decoder_seq_lod, + paddle::Tensor& encoder_kv_lod, + paddle::Tensor& prefix_len, + paddle::Tensor& decoder_context_len, + paddle::Tensor& decoder_context_len_cache, + paddle::Tensor& prefix_block_tables, + paddle::Tensor& encoder_batch_map_cpu, + paddle::Tensor& decoder_batch_map_cpu, + paddle::Tensor& encoder_batch_idx_cpu, + paddle::Tensor& decoder_batch_idx_cpu, + paddle::Tensor& encoder_seq_lod_cpu, + paddle::Tensor& decoder_seq_lod_cpu, + paddle::Tensor& encoder_kv_lod_cpu, + paddle::Tensor& prefix_len_cpu, + paddle::Tensor& decoder_context_len_cpu, + paddle::Tensor& decoder_context_len_cache_cpu, + paddle::Tensor& len_info_cpu, int block_size, int num_speculative_tokens) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); @@ -128,18 +192,18 @@ std::vector GetInferParam( if (seq_lens_encoder_vec[i] > 0) { enc_batch++; int seq_len = seq_lens_encoder_vec[i]; - int prefix_len = seq_lens_decoder_vec[i]; + int prefix_len_int = seq_lens_decoder_vec[i]; total_enc_len += seq_len; max_seq_len = std::max(max_seq_len, seq_len); - max_prefix_len = std::max(max_prefix_len, prefix_len); - max_kv_len = std::max(max_kv_len, seq_len + prefix_len); + max_prefix_len = std::max(max_prefix_len, prefix_len_int); + max_kv_len = std::max(max_kv_len, seq_len + prefix_len_int); encoder_batch_map_vec[enc_batch - 1] = i; encoder_batch_idx_vec[enc_batch - 1] = i - batch_offset; encoder_seq_lod_vec[enc_batch] = seq_len + encoder_seq_lod_vec[enc_batch - 1]; encoder_kv_lod_vec[enc_batch] = - seq_len + prefix_len + encoder_kv_lod_vec[enc_batch - 1]; - prefix_len_vec[enc_batch - 1] = prefix_len; + seq_len + prefix_len_int + encoder_kv_lod_vec[enc_batch - 1]; + prefix_len_vec[enc_batch - 1] = prefix_len_int; } else if (seq_lens_decoder_vec[i] > 0 && seq_lens_this_time_vec[i] > 0) { dec_batch++; max_dec_len = std::max(max_dec_len, seq_lens_this_time_vec[i]); @@ -186,42 +250,6 @@ std::vector GetInferParam( prefix_block_num_per_seq = -1; } - auto encoder_batch_map = paddle::empty({encoder_batch_map_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto decoder_batch_map = paddle::empty({decoder_batch_map_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto encoder_batch_idx = paddle::empty({encoder_batch_idx_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto decoder_batch_idx = paddle::empty({decoder_batch_idx_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto encoder_seq_lod = paddle::empty({encoder_seq_lod_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto decoder_seq_lod = paddle::empty({decoder_seq_lod_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto encoder_kv_lod = paddle::empty({encoder_kv_lod_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto prefix_len = paddle::empty({prefix_len_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto decoder_context_len = paddle::empty({decoder_context_len_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto decoder_context_len_cache = - paddle::empty({decoder_context_len_cache_vec.size()}, - seq_lens_encoder.type(), - seq_lens_encoder.place()); - auto prefix_block_tables = - paddle::empty({block_bs, block_num_per_seq}, // full size - seq_lens_encoder.type(), - seq_lens_encoder.place()); - // for store_paged_kv_cache of cudagraph mode // if slot_mapping is -1, store_paged_kv_cache will not write to kv cache paddle::Tensor slot_mapping_enc = paddle::full( @@ -232,72 +260,34 @@ std::vector GetInferParam( -1, paddle::DataType::INT32, seq_lens_decoder.place()); - if (FLAGS_encoder_splice || FLAGS_decoder_splice) { - std::vector block_tables_vec(block_bs * block_num_per_seq); - r = xpu_memcpy(block_tables_vec.data(), - block_tables.data(), - sizeof(int32_t) * block_bs * block_num_per_seq, - XPUMemcpyKind::XPU_DEVICE_TO_HOST); - if (FLAGS_encoder_splice) { - lod_to_slot_mapping(xpu_ctx->x_context(), - seq_lens_encoder.place(), - block_tables_vec, - encoder_seq_lod_vec, - prefix_len_vec, - encoder_batch_map_vec, - slot_mapping_enc.data(), - total_enc_len, - block_size, - enc_batch, - block_num_per_seq, - 0); - } - if (FLAGS_decoder_splice) { - lod_to_slot_mapping(xpu_ctx->x_context(), - seq_lens_decoder.place(), - block_tables_vec, - decoder_seq_lod_vec, - decoder_context_len_cache_vec, - decoder_batch_map_vec, - slot_mapping_dec.data(), - bsz * (1 + num_speculative_tokens), - block_size, - dec_batch, - block_num_per_seq, - num_speculative_tokens); - } + if (FLAGS_encoder_splice) { + lod_to_slot_mapping(xpu_ctx->x_context(), + seq_lens_encoder.place(), + block_tables, + encoder_seq_lod_vec, + prefix_len_vec, + encoder_batch_map_vec, + slot_mapping_enc.data(), + slot_mapping_enc.numel(), + block_size, + enc_batch, + block_num_per_seq, + 0); + } + if (FLAGS_decoder_splice) { + lod_to_slot_mapping(xpu_ctx->x_context(), + seq_lens_decoder.place(), + block_tables, + decoder_seq_lod_vec, + decoder_context_len_cache_vec, + decoder_batch_map_vec, + slot_mapping_dec.data(), + slot_mapping_dec.numel(), + block_size, + dec_batch, + block_num_per_seq, + num_speculative_tokens); } - - auto encoder_batch_map_cpu = paddle::empty({encoder_batch_map_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - auto decoder_batch_map_cpu = paddle::empty({decoder_batch_map_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - auto encoder_batch_idx_cpu = paddle::empty({encoder_batch_idx_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - auto decoder_batch_idx_cpu = paddle::empty({decoder_batch_idx_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - auto encoder_seq_lod_cpu = paddle::empty({encoder_seq_lod_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - auto decoder_seq_lod_cpu = paddle::empty({decoder_seq_lod_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - - auto encoder_kv_lod_cpu = paddle::empty( - {encoder_kv_lod_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace()); - auto prefix_len_cpu = paddle::empty( - {prefix_len_vec.size()}, seq_lens_encoder.type(), paddle::CPUPlace()); - auto decoder_context_len_cpu = paddle::empty({decoder_context_len_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); - auto decoder_context_len_cache_cpu = - paddle::empty({decoder_context_len_cache_vec.size()}, - seq_lens_encoder.type(), - paddle::CPUPlace()); ret = api::do_host2device( xpu_ctx->x_context(), @@ -400,65 +390,25 @@ std::vector GetInferParam( max_kv_len, prefix_block_num_per_seq, max_dec_len}; - auto len_info_cpu = - paddle::empty({7}, seq_lens_encoder.type(), paddle::CPUPlace()); + std::memcpy(len_info_cpu.data(), len_info_vec.data(), sizeof(int32_t) * len_info_vec.size()); - return {encoder_batch_map, - decoder_batch_map, - encoder_batch_idx, - decoder_batch_idx, - encoder_seq_lod, - decoder_seq_lod, - encoder_kv_lod, - prefix_len, - decoder_context_len, - decoder_context_len_cache, - prefix_block_tables, - encoder_batch_map_cpu, - decoder_batch_map_cpu, - encoder_batch_idx_cpu, - decoder_batch_idx_cpu, - encoder_seq_lod_cpu, - decoder_seq_lod_cpu, - encoder_kv_lod_cpu, - prefix_len_cpu, - decoder_context_len_cpu, - decoder_context_len_cache_cpu, - len_info_cpu, - slot_mapping_enc, - slot_mapping_dec}; + return {slot_mapping_enc, slot_mapping_dec}; } std::vector> GetInferParamInferShape( const std::vector& seq_lens_encoder_shape, const std::vector& seq_lens_decoder_shape, const std::vector& seq_lens_this_time_shape, - const std::vector& block_tables_shape) { - return {seq_lens_encoder_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - {seq_lens_encoder_shape[0] + 1}, - {seq_lens_encoder_shape[0] + 1}, - {seq_lens_encoder_shape[0] + 1}, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - block_tables_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - {seq_lens_encoder_shape[0] + 1}, - {seq_lens_encoder_shape[0] + 1}, - {seq_lens_encoder_shape[0] + 1}, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - seq_lens_encoder_shape, - {7}}; + const std::vector& block_tables_shape, + int num_speculative_tokens) { + // Return shapes for slot_mapping_enc and slot_mapping_dec + // slot_mapping_enc shape depends on encoder token count (unknown at shape + // inference time) slot_mapping_dec shape depends on batch size and + // speculative token count + return {{-1}, {seq_lens_encoder_shape[0] * (1 + num_speculative_tokens)}}; } std::vector GetInferParamInferDtype( @@ -466,46 +416,38 @@ std::vector GetInferParamInferDtype( const paddle::DataType& seq_lens_decoder_dtype, const paddle::DataType& seq_lens_this_time_dtype, const paddle::DataType& block_tables_dtype) { - return { - seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype, block_tables_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype, seq_lens_encoder_dtype, seq_lens_encoder_dtype, - seq_lens_encoder_dtype}; + // Return dtypes for slot_mapping_enc and slot_mapping_dec (both INT32) + return {paddle::DataType::INT32, paddle::DataType::INT32}; } PD_BUILD_OP(get_infer_param) .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", - "block_tables"}) - .Outputs({"encoder_batch_map", - "decoder_batch_map", - "encoder_batch_idx", - "decoder_batch_idx", - "encoder_seq_lod", - "decoder_seq_lod", - "encoder_kv_lod", - "prefix_len", - "decoder_context_len", - "decoder_context_len_cache", - "prefix_block_tables", - "encoder_batch_map_cpu", - "decoder_batch_map_cpu", - "encoder_batch_idx_cpu", - "decoder_batch_idx_cpu", - "encoder_seq_lod_cpu", - "decoder_seq_lod_cpu", - "encoder_kv_lod_cpu", - "prefix_len_cpu", - "decoder_context_len_cpu", - "decoder_context_len_cache_cpu", - "len_info_cpu", - "slot_mapping_enc", - "slot_mapping_dec"}) + "block_tables", + "encoder_batch_map", + "decoder_batch_map", + "encoder_batch_idx", + "decoder_batch_idx", + "encoder_seq_lod", + "decoder_seq_lod", + "encoder_kv_lod", + "prefix_len", + "decoder_context_len", + "decoder_context_len_cache", + "prefix_block_tables", + "encoder_batch_map_cpu", + "decoder_batch_map_cpu", + "encoder_batch_idx_cpu", + "decoder_batch_idx_cpu", + "encoder_seq_lod_cpu", + "decoder_seq_lod_cpu", + "encoder_kv_lod_cpu", + "prefix_len_cpu", + "decoder_context_len_cpu", + "decoder_context_len_cache_cpu", + "len_info_cpu"}) + .Outputs({"slot_mapping_enc", "slot_mapping_dec"}) .SetKernelFn(PD_KERNEL(GetInferParam)) .Attrs({"block_size: int", "num_speculative_tokens: int"}) .SetInferShapeFn(PD_INFER_SHAPE(GetInferParamInferShape)) diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 887a9ac95a1..815a2875a1a 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -478,6 +478,28 @@ std::vector GetInferParam( const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& block_tables, + paddle::Tensor& encoder_batch_map, + paddle::Tensor& decoder_batch_map, + paddle::Tensor& encoder_batch_idx, + paddle::Tensor& decoder_batch_idx, + paddle::Tensor& encoder_seq_lod, + paddle::Tensor& decoder_seq_lod, + paddle::Tensor& encoder_kv_lod, + paddle::Tensor& prefix_len, + paddle::Tensor& decoder_context_len, + paddle::Tensor& decoder_context_len_cache, + paddle::Tensor& prefix_block_tables, + paddle::Tensor& encoder_batch_map_cpu, + paddle::Tensor& decoder_batch_map_cpu, + paddle::Tensor& encoder_batch_idx_cpu, + paddle::Tensor& decoder_batch_idx_cpu, + paddle::Tensor& encoder_seq_lod_cpu, + paddle::Tensor& decoder_seq_lod_cpu, + paddle::Tensor& encoder_kv_lod_cpu, + paddle::Tensor& prefix_len_cpu, + paddle::Tensor& decoder_context_len_cpu, + paddle::Tensor& decoder_context_len_cache_cpu, + paddle::Tensor& len_info_cpu, int block_size, int num_speculative_tokens); @@ -1052,6 +1074,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_decoder"), py::arg("seq_lens_this_time"), py::arg("block_tables"), + py::arg("encoder_batch_map"), + py::arg("decoder_batch_map"), + py::arg("encoder_batch_idx"), + py::arg("decoder_batch_idx"), + py::arg("encoder_seq_lod"), + py::arg("decoder_seq_lod"), + py::arg("encoder_kv_lod"), + py::arg("prefix_len"), + py::arg("decoder_context_len"), + py::arg("decoder_context_len_cache"), + py::arg("prefix_block_tables"), + py::arg("encoder_batch_map_cpu"), + py::arg("decoder_batch_map_cpu"), + py::arg("encoder_batch_idx_cpu"), + py::arg("decoder_batch_idx_cpu"), + py::arg("encoder_seq_lod_cpu"), + py::arg("decoder_seq_lod_cpu"), + py::arg("encoder_kv_lod_cpu"), + py::arg("prefix_len_cpu"), + py::arg("decoder_context_len_cpu"), + py::arg("decoder_context_len_cache_cpu"), + py::arg("len_info_cpu"), py::arg("block_size"), py::arg("num_speculative_tokens"), "Get infer parameters for block attention in XPU"); diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py index bc074242b4e..f9de20b8875 100644 --- a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py +++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py @@ -16,6 +16,7 @@ import numpy as np import paddle +from utils import init_inplace_tensor from fastdeploy.model_executor.ops.xpu import ( adjust_batch, @@ -33,11 +34,8 @@ def _run_test_base(seq_lens_this_time_data, is_speculative): seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32") bsz = seq_lens_this_time.shape[0] - cum_offsets = paddle.zeros(bsz, dtype="int32") block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8)) - infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64) - ( encoder_batch_map, decoder_batch_map, @@ -45,23 +43,56 @@ def _run_test_base(seq_lens_this_time_data, is_speculative): decoder_batch_idx, encoder_seq_lod, decoder_seq_lod, - _, - _, - _, - _, - _, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape) + ( + slot_mapping_enc, + slot_mapping_dec, + ) = get_infer_param( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_table, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, encoder_batch_map_cpu, decoder_batch_map_cpu, encoder_batch_idx_cpu, decoder_batch_idx_cpu, encoder_seq_lod_cpu, decoder_seq_lod_cpu, - _, - _, - _, - _, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, len_info_cpu, - ) = infer_params + 64, + 0, + ) token_num = seq_lens_this_time.sum().cpu().item() hidden_dim = 8192 @@ -72,7 +103,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative): # 测试 adjust_batch adjusted_output = adjust_batch( input_tensor, - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_idx, @@ -88,7 +118,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative): adjusted_output_cpu = adjust_batch( input_tensor.cpu(), - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_idx, @@ -110,7 +139,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative): # 测试 gather_next_token gather_out = gather_next_token( adjusted_output, - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_map, @@ -126,7 +154,6 @@ def _run_test_base(seq_lens_this_time_data, is_speculative): gather_out_cpu = gather_next_token( adjusted_output.cpu(), - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_map, diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_recover_batch_sequence.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_recover_batch_sequence.py index 9a1463c4b7d..5e5f55df42b 100644 --- a/custom_ops/xpu_ops/test/test_adjust_batch_and_recover_batch_sequence.py +++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_recover_batch_sequence.py @@ -16,6 +16,7 @@ import numpy as np import paddle +from utils import init_inplace_tensor from fastdeploy.model_executor.ops.xpu import ( adjust_batch, @@ -33,8 +34,6 @@ def _run_test_base(seq_lens_this_time_data): cum_offsets = paddle.zeros(bsz, dtype="int32") block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8)) - infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64) - ( encoder_batch_map, decoder_batch_map, @@ -42,23 +41,56 @@ def _run_test_base(seq_lens_this_time_data): decoder_batch_idx, encoder_seq_lod, decoder_seq_lod, - _, - _, - _, - _, - _, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape) + ( + slot_mapping_enc, + slot_mapping_dec, + ) = get_infer_param( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_table, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, encoder_batch_map_cpu, decoder_batch_map_cpu, encoder_batch_idx_cpu, decoder_batch_idx_cpu, encoder_seq_lod_cpu, decoder_seq_lod_cpu, - _, - _, - _, - _, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, len_info_cpu, - ) = infer_params + 64, + 0, + ) token_num = seq_lens_this_time.sum().cpu().item() hidden_dim = 8192 @@ -68,7 +100,6 @@ def _run_test_base(seq_lens_this_time_data): # test adjust_batch adjusted_output = adjust_batch( input_tensor, - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_idx, @@ -84,7 +115,6 @@ def _run_test_base(seq_lens_this_time_data): adjusted_output_cpu = adjust_batch( input_tensor.cpu(), - cum_offsets, encoder_seq_lod, decoder_seq_lod, encoder_batch_idx, diff --git a/custom_ops/xpu_ops/test/test_block_attn.py b/custom_ops/xpu_ops/test/test_block_attn.py index 51b6a7b9db3..f9c3be92562 100644 --- a/custom_ops/xpu_ops/test/test_block_attn.py +++ b/custom_ops/xpu_ops/test/test_block_attn.py @@ -16,6 +16,7 @@ import numpy as np import paddle +from utils import init_inplace_tensor # block_attn_fused is deprecated and should be removed in the future from fastdeploy.model_executor.ops.xpu import ( @@ -76,6 +77,7 @@ def run_prefix_cache_block_attn( # prefix cache block attn seq_lens_encoder = paddle.to_tensor([seq_len - hit_prefix_len, 0, 0, 0, 0], dtype="int32") seq_lens_decoder = paddle.to_tensor([hit_prefix_len, 0, 0, 0, 0], dtype="int32") + ( encoder_batch_map, decoder_batch_map, @@ -99,11 +101,40 @@ def run_prefix_cache_block_attn( decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, + ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape) + ( slot_mapping_enc, slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens - ) # block_size + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_tables, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + 64, + num_speculative_tokens, + ) qkv_prefix = qkv[hit_prefix_len:] attn_out_prefix_cache = block_attn_func( qkv_prefix, @@ -194,6 +225,7 @@ def run_block_attn( seq_lens_this_time = paddle.to_tensor([seq_len, 0, 0, 0, 0], dtype="int32") block_tables = paddle.arange(0, block_batch * max_block_per_seq, dtype="int32") block_tables = block_tables.reshape((block_batch, max_block_per_seq)) + ( encoder_batch_map, decoder_batch_map, @@ -217,10 +249,39 @@ def run_block_attn( decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, + ) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape) + ( slot_mapping_enc, slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, num_speculative_tokens + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_tables, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + 64, + num_speculative_tokens, ) qkv = paddle.uniform( shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], dtype="bfloat16", min=-1.0, max=1.0, seed=seed diff --git a/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py b/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py index c15a5f4e4af..f877b09d304 100644 --- a/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py +++ b/custom_ops/xpu_ops/test/test_block_attn_prefix_cache.py @@ -14,6 +14,7 @@ import numpy as np import paddle +from utils import init_inplace_tensor from fastdeploy.model_executor.ops.xpu import block_attn_fused, get_infer_param @@ -24,6 +25,7 @@ block_batch = 5 max_block_per_seq = 128 block_size = 64 +num_speculative_tokens = 0 seq_lens_encoder = paddle.to_tensor([128, 0, 0, 0, 0], dtype="int32") seq_lens_decoder = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") @@ -53,11 +55,40 @@ decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, +) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape) +( slot_mapping_enc, slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0 -) # block_size + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_tables, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + 64, + num_speculative_tokens, +) qkv = paddle.uniform( shape=[seq_len, (head_num + 2 * kv_head_num) * head_dim], @@ -247,11 +278,40 @@ decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, +) = init_inplace_tensor(seq_lens_encoder.shape[0], block_tables.shape) +( slot_mapping_enc, slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_tables, 64, 0 -) # block_size + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_tables, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + 64, + num_speculative_tokens, +) qkv_prefix = qkv[hit_prefix_len:] attn_out_prefix_cache = block_attn_fused( diff --git a/custom_ops/xpu_ops/test/test_get_infer_param.py b/custom_ops/xpu_ops/test/test_get_infer_param.py index f1b99239530..0dc67bebd94 100755 --- a/custom_ops/xpu_ops/test/test_get_infer_param.py +++ b/custom_ops/xpu_ops/test/test_get_infer_param.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle +from utils import init_inplace_tensor from fastdeploy.model_executor.ops.xpu import get_infer_param @@ -21,6 +22,7 @@ seq_lens_this_time = paddle.to_tensor([100, 1, 0, 1, 300], dtype="int32") block_table = paddle.arange(0, 40, dtype="int32") block_table = block_table.reshape((5, 8)) + ( encoder_batch_map, decoder_batch_map, @@ -44,9 +46,40 @@ decoder_context_len_cpu, decoder_context_len_cache_cpu, len_info_cpu, +) = init_inplace_tensor(seq_lens_encoder.shape[0], block_table.shape) +( + slot_mapping_enc, + slot_mapping_dec, ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64 -) # block_size + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_table, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + 64, + 0, +) print("block_table", block_table) print("encoder_batch_map", encoder_batch_map) # [0, 4, 0, 0, 0] diff --git a/custom_ops/xpu_ops/test/utils.py b/custom_ops/xpu_ops/test/utils.py new file mode 100644 index 00000000000..5df8ba4af70 --- /dev/null +++ b/custom_ops/xpu_ops/test/utils.py @@ -0,0 +1,54 @@ +import paddle + + +def init_inplace_tensor(bsz, block_tables_shape): + encoder_batch_map = paddle.empty(bsz, dtype="int32") + decoder_batch_map = paddle.empty(bsz, dtype="int32") + encoder_batch_idx = paddle.empty(bsz, dtype="int32") + decoder_batch_idx = paddle.empty(bsz, dtype="int32") + encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32") + decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32") + encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32") + prefix_len = paddle.empty(bsz, dtype="int32") + decoder_context_len = paddle.empty(bsz, dtype="int32") + decoder_context_len_cache = paddle.empty(bsz, dtype="int32") + + prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32") + + encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu") + decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu") + encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu") + prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + + len_info_cpu = paddle.empty(7, dtype="int32", device="cpu") + + return ( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, + ) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 509f774ea83..516344a17f4 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -289,6 +289,33 @@ class XPUForwardMeta(ForwardMeta): # slot_mapping_dec: Optional[paddle.Tensor] = None + def init_inplace_tensor(self, bsz, block_tables_shape): + self.encoder_batch_map = paddle.empty(bsz, dtype="int32") + self.decoder_batch_map = paddle.empty(bsz, dtype="int32") + self.encoder_batch_idx = paddle.empty(bsz, dtype="int32") + self.decoder_batch_idx = paddle.empty(bsz, dtype="int32") + self.encoder_seq_lod = paddle.empty(bsz + 1, dtype="int32") + self.decoder_seq_lod = paddle.empty(bsz + 1, dtype="int32") + self.encoder_kv_lod = paddle.empty(bsz + 1, dtype="int32") + self.prefix_len = paddle.empty(bsz, dtype="int32") + self.decoder_context_len = paddle.empty(bsz, dtype="int32") + self.decoder_context_len_cache = paddle.empty(bsz, dtype="int32") + + self.prefix_block_tables = paddle.empty(block_tables_shape, dtype="int32") + + self.encoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + self.decoder_batch_map_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + self.encoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + self.decoder_batch_idx_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + self.encoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu") + self.decoder_seq_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu") + self.encoder_kv_lod_cpu = paddle.empty(bsz + 1, dtype="int32", device="cpu") + self.prefix_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + self.decoder_context_len_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + self.decoder_context_len_cache_cpu = paddle.empty(bsz, dtype="int32", device="cpu") + + self.len_info_cpu = paddle.empty(7, dtype="int32", device="cpu") + def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None): """ Synchronize attributes from another XPUForwardMeta object diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 523a08d0c11..f2f2c157e21 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -158,70 +158,113 @@ def xpu_pre_process( share_inputs["cu_seqlens_q"] = cu_seqlens_q share_inputs["cu_seqlens_k"] = cu_seqlens_k - xpu_forward_meta = XPUForwardMeta( - ids_remove_padding=share_inputs["ids_remove_padding"], - rotary_embs=share_inputs["rope_emb"], - attn_backend=None, - seq_lens_encoder=share_inputs["seq_lens_encoder"], - seq_lens_decoder=share_inputs["seq_lens_decoder"], - seq_lens_this_time=share_inputs["seq_lens_this_time"], - batch_id_per_token=share_inputs["batch_id_per_token"], - cu_seqlens_q=share_inputs["cu_seqlens_q"], - cu_seqlens_k=share_inputs["cu_seqlens_k"], - block_tables=share_inputs["block_tables"], - caches=share_inputs["caches"], - max_num_seqs=share_inputs["seq_lens_this_time"].shape[0], - is_speculative=use_speculate_method, - ) + if use_cudagraph and forward_meta is not None: + forward_meta.ids_remove_padding.copy_(share_inputs["ids_remove_padding"], False) + forward_meta.rotary_embs.copy_(share_inputs["rope_emb"], False) + forward_meta.attn_backend = None + forward_meta.seq_lens_encoder.copy_(share_inputs["seq_lens_encoder"], False) + forward_meta.seq_lens_decoder.copy_(share_inputs["seq_lens_decoder"], False) + forward_meta.seq_lens_this_time.copy_(share_inputs["seq_lens_this_time"], False) + forward_meta.batch_id_per_token.copy_(share_inputs["batch_id_per_token"], False) + forward_meta.cu_seqlens_q.copy_(share_inputs["cu_seqlens_q"], False) + forward_meta.cu_seqlens_k.copy_(share_inputs["cu_seqlens_k"], False) + forward_meta.block_tables.copy_(share_inputs["block_tables"], False) + forward_meta.caches = share_inputs["caches"] + forward_meta.max_num_seqs = share_inputs["seq_lens_this_time"].shape[0] + forward_meta.is_speculative = use_speculate_method + + xpu_forward_meta = forward_meta + else: + xpu_forward_meta = XPUForwardMeta( + ids_remove_padding=share_inputs["ids_remove_padding"], + rotary_embs=share_inputs["rope_emb"], + attn_backend=None, + seq_lens_encoder=share_inputs["seq_lens_encoder"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + seq_lens_this_time=share_inputs["seq_lens_this_time"], + batch_id_per_token=share_inputs["batch_id_per_token"], + cu_seqlens_q=share_inputs["cu_seqlens_q"], + cu_seqlens_k=share_inputs["cu_seqlens_k"], + block_tables=share_inputs["block_tables"], + caches=share_inputs["caches"], + max_num_seqs=share_inputs["seq_lens_this_time"].shape[0], + is_speculative=use_speculate_method, + ) + xpu_forward_meta.init_inplace_tensor(seq_lens_encoder.shape[0], share_inputs["block_tables"].shape) + + block_tables = xpu_forward_meta.block_tables + + encoder_batch_map = xpu_forward_meta.encoder_batch_map + decoder_batch_map = xpu_forward_meta.decoder_batch_map + encoder_batch_idx = xpu_forward_meta.encoder_batch_idx + decoder_batch_idx = xpu_forward_meta.decoder_batch_idx + encoder_seq_lod = xpu_forward_meta.encoder_seq_lod + decoder_seq_lod = xpu_forward_meta.decoder_seq_lod + encoder_kv_lod = xpu_forward_meta.encoder_kv_lod + prefix_len = xpu_forward_meta.prefix_len + decoder_context_len = xpu_forward_meta.decoder_context_len + decoder_context_len_cache = xpu_forward_meta.decoder_context_len_cache + + prefix_block_tables = xpu_forward_meta.prefix_block_tables + + encoder_batch_map_cpu = xpu_forward_meta.encoder_batch_map_cpu + decoder_batch_map_cpu = xpu_forward_meta.decoder_batch_map_cpu + encoder_batch_idx_cpu = xpu_forward_meta.encoder_batch_idx_cpu + decoder_batch_idx_cpu = xpu_forward_meta.decoder_batch_idx_cpu + encoder_seq_lod_cpu = xpu_forward_meta.encoder_seq_lod_cpu + decoder_seq_lod_cpu = xpu_forward_meta.decoder_seq_lod_cpu + encoder_kv_lod_cpu = xpu_forward_meta.encoder_kv_lod_cpu + prefix_len_cpu = xpu_forward_meta.prefix_len_cpu + decoder_context_len_cpu = xpu_forward_meta.decoder_context_len_cpu + decoder_context_len_cache_cpu = xpu_forward_meta.decoder_context_len_cache_cpu + + len_info_cpu = xpu_forward_meta.len_info_cpu ( - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_kv_lod, - xpu_forward_meta.prefix_len, - xpu_forward_meta.decoder_context_len, - xpu_forward_meta.decoder_context_len_cache, - xpu_forward_meta.prefix_block_tables, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_kv_lod_cpu, - xpu_forward_meta.prefix_len_cpu, - xpu_forward_meta.decoder_context_len_cpu, - xpu_forward_meta.decoder_context_len_cache_cpu, - xpu_forward_meta.len_info_cpu, - xpu_forward_meta.slot_mapping_enc, - xpu_forward_meta.slot_mapping_dec, + slot_mapping_enc, + slot_mapping_dec, ) = get_infer_param( seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, - xpu_forward_meta.block_tables, + block_tables, + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + encoder_kv_lod, + prefix_len, + decoder_context_len, + decoder_context_len_cache, + prefix_block_tables, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_kv_lod_cpu, + prefix_len_cpu, + decoder_context_len_cpu, + decoder_context_len_cache_cpu, + len_info_cpu, block_size, num_speculative_tokens, ) - xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] - xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] - xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] adjusted_input = adjust_batch( ids_remove_padding.reshape([-1, 1]), - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.len_info_cpu, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + len_info_cpu, None, # output_padding_offset -1, # max bs ) @@ -229,17 +272,22 @@ def xpu_pre_process( adjusted_input = adjusted_input.squeeze(1) share_inputs["ids_remove_padding"].copy_(adjusted_input, False) + + xpu_forward_meta.enc_batch = len_info_cpu[0] + xpu_forward_meta.dec_batch = len_info_cpu[1] + xpu_forward_meta.total_enc_len = len_info_cpu[2] xpu_forward_meta.ids_remove_padding = adjusted_input - # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends + # Set xpu_forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends xpu_forward_meta.is_profiling = is_profiling - if use_cudagraph: - if forward_meta is None: - return xpu_forward_meta - else: - forward_meta.copy_from(xpu_forward_meta) - return forward_meta + + # prefill does not use cudagraph, inplace copy is not needed + xpu_forward_meta.slot_mapping_enc = slot_mapping_enc + if use_cudagraph and forward_meta is not None: + xpu_forward_meta.slot_mapping_dec.copy_(slot_mapping_dec, False) else: - return xpu_forward_meta + xpu_forward_meta.slot_mapping_dec = slot_mapping_dec + + return xpu_forward_meta def xpu_process_output(