@michaelfeil michaelfeil commented Nov 21, 2025

What does this PR do?

This PR open sources the implementation of RadixMLP for TEI. (arxiv link pending)

High-throughput LLM batch inference relies on careful handling of shared prefixes and scheduling, as well as on an implementation of cached input tokens. However, modern inference systems often fail to exploit the KV cache, either 1) because scheduling complexities caused the relevant KV cache entries to be evicted earlier, or 2) because they do not implement a full KV management system (RadixTree, paged memory, and swapping kernels). To address this problem, RadixMLP implements a deduplication technique that skips identical computation within a single batch, exploiting the principle that an identical set of token ids and positions can be referenced by multiple branching sequences. Duplicate entries in the batch can be folded before the MLP block, and the resulting values can be scattered back to the original input shape.

Conceptually, the code in radix_mlp.rs and its tests are probably the most helpful:

        // tokens    = [a,b,c,d,e,f,g, a,b,c, e,f,g,h,i]
        // pos       = [0,1,2,3,4,5,6, 0,1,2, 3,4,5,6,7]
        // cu_seqlen = [0,7,15]
        // Expected folded:
        // tokens    = [a,b,c, d,e,f,g, e,f,g,h,i]
        // pos       = [0,1,2, 3,4,5,6, 3,4,5,6,7]
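The folding in the example above can be sketched as follows. This is a minimal illustration, not the code from radix_mlp.rs: the function name `fold_and_scatter_sketch`, its signature, and the choice to key each token by (parent node, token id, position) are assumptions made for this sketch. It also treats the example as two packed sequences of lengths 7 and 8, with the letters a..i mapped to the ids 1..9.

```rust
use std::collections::HashMap;

/// Sketch of a prefix-based fold. Returns `(fold_gather, scatter)`:
/// - `fold_gather[j]` is the original index of the j-th unique token,
/// - `scatter[i]` is the folded index that original token `i` maps to.
fn fold_and_scatter_sketch(
    input_ids: &[u32],
    position_ids: &[u32],
    cu_seq_lengths: &[usize],
) -> (Vec<usize>, Vec<usize>) {
    // Key: (folded index of the previous token, token id, position).
    // usize::MAX marks the start of a sequence (no parent).
    let mut seen: HashMap<(usize, u32, u32), usize> = HashMap::new();
    let mut fold_gather: Vec<usize> = Vec::new();
    let mut scatter = vec![0usize; input_ids.len()];

    for bounds in cu_seq_lengths.windows(2) {
        let mut parent = usize::MAX;
        for i in bounds[0]..bounds[1] {
            let next = fold_gather.len();
            let folded = *seen
                .entry((parent, input_ids[i], position_ids[i]))
                .or_insert(next);
            if folded == next {
                // First time this (prefix, token, position) is seen.
                fold_gather.push(i);
            }
            scatter[i] = folded;
            parent = folded;
        }
    }
    (fold_gather, scatter)
}

fn main() {
    // tokens a..i are encoded as 1..9
    let input_ids: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 5, 6, 7, 8, 9];
    let position_ids: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7];
    let cu_seq_lengths = vec![0usize, 7, 15];

    let (fold_gather, scatter) =
        fold_and_scatter_sketch(&input_ids, &position_ids, &cu_seq_lengths);

    let folded_ids: Vec<u32> = fold_gather.iter().map(|&i| input_ids[i]).collect();
    let folded_pos: Vec<u32> = fold_gather.iter().map(|&i| position_ids[i]).collect();
    assert_eq!(folded_ids, vec![1, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 9]);
    assert_eq!(folded_pos, vec![0, 1, 2, 3, 4, 5, 6, 3, 4, 5, 6, 7]);

    // Scattering the folded values back restores the original layout.
    let restored: Vec<u32> = scatter.iter().map(|&j| folded_ids[j]).collect();
    assert_eq!(restored, input_ids);

    println!("{} -> {} tokens", input_ids.len(), folded_ids.len()); // 15 -> 12 tokens
}
```

Keying on the parent node rather than on (token id, position) alone keeps two tokens apart when they match in id and position but follow different prefixes, which matters because their hidden states differ under causal attention.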

Performance

Small synthetic performance benchmark:

  • 32 clients each send 1 JSON request; all requests are identical
  • 1 JSON request = 512 sentences, each sentence is unique

Conceptually, how RadixMLP helps:

  • if two sentences are identical or share a prefix (e.g. the same sentence from 2 clients), we can deduplicate the MLP forward pass.
  • assuming throughput is bound 1/3 by the attention layers (512 ctx length) and 2/3 by the MLP layers, we should be able to reduce the MLP effort substantially, say to 0.2x (best case 1/32, which requires a very large batch size). The total runtime then goes from 1 to 1/3 + 2/3 * 0.2 ≈ 0.47, i.e. roughly 2.1x the baseline throughput.
ab -n 32 -c 32 -l -s 480 -T 'application/json' -p ./embed_throughput.json "http://0.0.0.0:7997/v1/embeddings"
# embed_throughput.json contains 512 sentences of 512 tokens each.
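The back-of-the-envelope estimate from the list above can be reproduced with a few lines. This is only a sketch of the arithmetic; the helper name `relative_runtime` is made up for illustration, and the 1/3 vs 2/3 split is the assumption stated in the estimate, not a measurement.

```rust
/// Amdahl-style estimate: the fraction of the original runtime that remains
/// when only the MLP share is scaled by `mlp_factor`.
fn relative_runtime(attention_share: f64, mlp_share: f64, mlp_factor: f64) -> f64 {
    attention_share + mlp_share * mlp_factor
}

fn main() {
    // Assumed case: MLP work reduced to 0.2x.
    let t = relative_runtime(1.0 / 3.0, 2.0 / 3.0, 0.2);
    println!("relative runtime:  {:.3}", t); // 0.467
    println!("throughput factor: {:.2}x", 1.0 / t); // 2.14x

    // Best case: 32 identical requests, MLP work reduced to 1/32.
    let best = relative_runtime(1.0 / 3.0, 2.0 / 3.0, 1.0 / 32.0);
    println!("best-case factor:  {:.2}x", 1.0 / best); // 2.82x
}
```

Under these assumptions the attention share caps the gain: even with the MLP cost going to zero, throughput could not exceed 3x.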

RadixMLP disabled

(Near-zero overhead, same as before this PR; RadixMLP is disabled with the flag --radix-mlp-threshold 0.0.)
Requests per second: 0.49 [#/sec] (mean)

text-embeddings-router --model-id michaelfeil/Qwen3-Embedding-8B-auto --max-batch-tokens 40960 --port 7997 --max-client-batch-size 512 --radix-mlp-threshold 0.0

RadixMLP enabled

(RadixMLP is enabled for every batch dispatched by queue.rs in which at least 10% of the tokens are saved in the MLP layer.)
Requests per second: 0.85 [#/sec] (mean)

text-embeddings-router --model-id michaelfeil/Qwen3-Embedding-8B-auto --max-batch-tokens 40960 --port 7997 --max-client-batch-size 512 --radix-mlp-threshold 0.9

RadixMLP enabled, large batch size

Requests per second:    0.89 [#/sec] (mean)
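For a quick comparison, the measured requests-per-second figures above can be turned into speedups relative to the disabled baseline. The helper name `speedup` is made up for this sketch; the three numbers are the ones reported in this description.

```rust
/// Speedup of a run relative to the baseline, both in requests per second.
fn speedup(baseline_rps: f64, rps: f64) -> f64 {
    rps / baseline_rps
}

fn main() {
    let baseline = 0.49; // RadixMLP disabled
    let runs = [("enabled", 0.85), ("enabled, large batch size", 0.89)];
    for (label, rps) in runs {
        println!("{}: {:.2}x", label, speedup(baseline, rps));
    }
    // enabled: 1.73x
    // enabled, large batch size: 1.82x
}
```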

For now, you can verify that RadixMLP is enabled and in use from the log output:

2025-11-21T16:29:13.274086Z  INFO batching_task:next_batch: text_embeddings_core::queue: core/src/queue.rs:199: RadixMLP compression ratio: 0.87 (48125 -> 41919)
2025-11-21T16:29:14.406218Z  INFO openai_embed{total_time="26.904671951s" tokenization_time="4.666986ms" queue_time="2.304806018s" inference_time="1.75808412s"}: text_embeddings_router::http::server: router/src/http/server.rs:1290: Success
2025-11-21T16:29:14.410873Z  INFO batching_task:next_batch: text_embeddings_core::queue: core/src/queue.rs:199: RadixMLP compression ratio: 0.32 (140337 -> 44474)
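The compression ratios in these log lines are consistent with folded tokens divided by original tokens. The sketch below recomputes them and shows one way a threshold gate could work. The function names and the strict `<` comparison are assumptions for illustration, not the logic in queue.rs.

```rust
/// Ratio of tokens that still enter the MLP after folding (1.0 = nothing saved).
fn compression_ratio(original_tokens: usize, folded_tokens: usize) -> f64 {
    if original_tokens == 0 {
        return 1.0;
    }
    folded_tokens as f64 / original_tokens as f64
}

/// Hypothetical gate: fold only when the ratio is below the threshold.
/// With a threshold of 0.0 this never fires, matching the "disabled" setting.
fn should_use_radix_mlp(original_tokens: usize, folded_tokens: usize, threshold: f64) -> bool {
    compression_ratio(original_tokens, folded_tokens) < threshold
}

fn main() {
    // Token counts taken from the two log lines above.
    for (original, folded) in [(48125usize, 41919usize), (140337, 44474)] {
        println!(
            "ratio {:.2} ({} -> {}), fold at 0.9: {}, fold at 0.0: {}",
            compression_ratio(original, folded),
            original,
            folded,
            should_use_radix_mlp(original, folded, 0.9),
            should_use_radix_mlp(original, folded, 0.0),
        );
    }
}
```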

How:

  • only enabled for causal LMs (Qwen, Llama, Mistral) via the Backend trait, similar to is_padded
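A capability flag on a trait, in the spirit of is_padded, might look roughly like the sketch below. The trait here is a stripped-down stand-in, and the method name `supports_radix_mlp`, the two example structs, and the helper `radix_mlp_allowed` are all hypothetical; the actual trait in the PR may differ.

```rust
/// Simplified stand-in for a backend trait.
trait Backend {
    /// Whether the backend expects padded batches.
    fn is_padded(&self) -> bool;

    /// Hypothetical opt-in flag. The default is `false`, so existing
    /// backends keep their behavior unless they override it.
    fn supports_radix_mlp(&self) -> bool {
        false
    }
}

/// Stand-in for a causal LM backend (e.g. Qwen, Llama, Mistral).
struct CausalLm;
/// Stand-in for a bidirectional encoder backend.
struct EncoderOnly;

impl Backend for CausalLm {
    fn is_padded(&self) -> bool {
        false
    }
    fn supports_radix_mlp(&self) -> bool {
        true
    }
}

impl Backend for EncoderOnly {
    fn is_padded(&self) -> bool {
        true
    }
}

/// The batching side only needs to ask the backend before folding.
fn radix_mlp_allowed(backend: &dyn Backend) -> bool {
    backend.supports_radix_mlp()
}

fn main() {
    println!("causal:  padded={} radix_mlp={}", CausalLm.is_padded(), radix_mlp_allowed(&CausalLm));
    println!("encoder: padded={} radix_mlp={}", EncoderOnly.is_padded(), radix_mlp_allowed(&EncoderOnly));
}
```

Restricting the flag to causal models fits the technique: with causal attention a token's hidden state depends only on its prefix, so tokens with an identical prefix can share the per-token computation, which does not hold for bidirectional attention.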

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. // Private stack conversation.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests? If applicable, did you include or update the insta snapshots?

Who can review?

@alvarobartt

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@michaelfeil michaelfeil marked this pull request as draft November 21, 2025 15:14
michaelfeil commented Nov 21, 2025

Tests are passing
 Compiling text-embeddings-backend-candle v1.8.3 (/workspace/model-performance/kebao-node2/text-embeddings-inference/backends/candle)
 Compiling text-embeddings-backend v1.8.3 (/workspace/model-performance/kebao-node2/text-embeddings-inference/backends)
 Compiling text-embeddings-core v1.8.3 (/workspace/model-performance/kebao-node2/text-embeddings-inference/core)
 Compiling text-embeddings-router v1.8.3 (/workspace/model-performance/kebao-node2/text-embeddings-inference/router)
  Finished `test` profile [unoptimized + debuginfo] target(s) in 14.56s
   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/backend_grpc_client-fbff104cad313831)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_backend-723961cb43569a29)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_backend_candle-69e05a50e10dcc85)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/common.rs (/node-storage/cargo-target/debug/deps/common-d6094a4d78f4ceea)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_bert.rs (/node-storage/cargo-target/debug/deps/test_bert-0e729260929e716c)

running 4 tests
test test_bert_pooled_raw ... ok
test test_bert ... ok
test test_emotions ... ok
test test_bert_classification ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 6.87s

   Running tests/test_dense.rs (/node-storage/cargo-target/debug/deps/test_dense-6106d23bbaf60659)

running 2 tests
test test_stella_en_400m_v5_default_dense ... ok
test test_stella_en_400m_v5_dense_768 ... ok

test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 20.63s

   Running tests/test_flash_bert.rs (/node-storage/cargo-target/debug/deps/test_flash_bert-010c87f5a13db687)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_gte.rs (/node-storage/cargo-target/debug/deps/test_flash_gte-6ba299a781ad8a19)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_jina.rs (/node-storage/cargo-target/debug/deps/test_flash_jina-db5f6c30ddfef3c3)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_jina_code.rs (/node-storage/cargo-target/debug/deps/test_flash_jina_code-6e5bbc7ee81fffad)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_mistral.rs (/node-storage/cargo-target/debug/deps/test_flash_mistral-758849dcf83d6d2b)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_nomic.rs (/node-storage/cargo-target/debug/deps/test_flash_nomic-796880e608950cfe)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_qwen2.rs (/node-storage/cargo-target/debug/deps/test_flash_qwen2-050acb3dd261b0bb)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_flash_qwen3.rs (/node-storage/cargo-target/debug/deps/test_flash_qwen3-1175a4089d8a7625)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_gemma3.rs (/node-storage/cargo-target/debug/deps/test_gemma3-cd1547f1300512c3)

running 1 test
test test_gemma3 ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 8.23s

   Running tests/test_gte.rs (/node-storage/cargo-target/debug/deps/test_gte-1a6ee5baa700cd58)

running 4 tests
test test_alibaba_gte_new ... ok
test test_alibaba_gte ... ok
test test_gte_classification ... ok
test test_snowflake_gte ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 41.75s

   Running tests/test_jina.rs (/node-storage/cargo-target/debug/deps/test_jina-4b17c4da80b92b15)

running 2 tests
test test_jina_rerank has been running for over 60 seconds
test test_jina_small has been running for over 60 seconds
test test_jina_small ... ok
test test_jina_rerank ... ok

test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 228.18s

   Running tests/test_jina_code.rs (/node-storage/cargo-target/debug/deps/test_jina_code-601ea68f0e760e1a)

running 1 test
test test_jina_code_base has been running for over 60 seconds
test test_jina_code_base ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 247.78s

   Running tests/test_modernbert.rs (/node-storage/cargo-target/debug/deps/test_modernbert-218ace6649a1c1b3)

running 4 tests
test test_modernbert ... ok
test test_modernbert_classification_mean_pooling ... ok
test test_modernbert_pooled_raw ... ok
test test_modernbert_classification ... ok

test result: ok. 4 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 18.97s

   Running tests/test_mpnet.rs (/node-storage/cargo-target/debug/deps/test_mpnet-34c9e708cc70a189)

running 2 tests
test test_mpnet_pooled_raw ... ok
test test_mpnet ... ok

test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 5.31s

   Running tests/test_nomic.rs (/node-storage/cargo-target/debug/deps/test_nomic-01fc2e318753abc2)

running 2 tests
test test_nomic_small ... ok
test test_nomic_moe ... ok

test result: ok. 2 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 7.07s

   Running tests/test_qwen3.rs (/node-storage/cargo-target/debug/deps/test_qwen3-381e19f4029ce144)

running 1 test
test test_qwen3 ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 17.36s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_backend_core-fe53607928c154ed)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_backend_ort-402b26c59adcef85)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_backend_python-dae5fe18ce71a96c)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_core-6bb3306508879b75)

running 16 tests
test radix_mlp::tests::test_compute_fold_and_scatter_different_positions ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_deterministic_ordering ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_example_from_comments ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_empty ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_edge_case_single_token ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_identical_sequences ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_no_overlap ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_partial_overlap ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_three_sequences_complex ... ok
test radix_mlp::tests::test_compute_fold_and_scatter_single_sequence ... ok
test radix_mlp::tests::test_fold_gather_points_to_first_occurrence ... ok
test radix_mlp::tests::test_padding_to_multiple_of_8 ... ok
test radix_mlp::tests::test_radix_mlp_edge_cases_parameterized ... ok
test radix_mlp::tests::test_radix_mlp_reconstruction_parameterized ... ok
test radix_mlp::tests::fail_and_report_time_large_batch ... ok
test tokenization::tests::tokenizer ... ok

test result: ok. 16 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 2.22s

   Running unittests src/lib.rs (/node-storage/cargo-target/debug/deps/text_embeddings_router-c54f9d15e115ab43)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running unittests src/main.rs (/node-storage/cargo-target/debug/deps/text_embeddings_router-d90db65eb377ff58)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/common.rs (/node-storage/cargo-target/debug/deps/common-52d91fef52ad7b9e)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_http_embed.rs (/node-storage/cargo-target/debug/deps/test_http_embed-b56e231d358bc322)

running 1 test
test test_mrl_embeddings ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 5.38s

   Running tests/test_http_predict.rs (/node-storage/cargo-target/debug/deps/test_http_predict-601c0276d93bca83)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

   Running tests/test_http_rerank.rs (/node-storage/cargo-target/debug/deps/test_http_rerank-97187e2a5f06a515)

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests backend_grpc_client

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_backend

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_backend_candle

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_backend_core

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_backend_ort

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_backend_python

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_core

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

 Doc-tests text_embeddings_router

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

@michaelfeil michaelfeil marked this pull request as ready for review November 21, 2025 17:00