[Bug]: FMHA kernels not found for W4A8 AWQ with FP8 KV Cache #9253

@DearPlanet

Description

System Info

  • Hardware: L20

  • CUDA Version: 12.9

The Docker environment is built from TensorRT-LLM source (tag 1.2.0rc1), only with CUDA downgraded to 12.9:

  • Python 3.12.3
  • tensorrt-llm 1.2.0rc1
  • tensorrt 10.11.0.33
  • torch 2.8.0a0+5228986c39.nv25.6

In addition, I have tried the official 1.2.0rc1 image, nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc1, with CUDA 13.0 and got the same error.

Who can help?

@Tracin

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. First, for Qwen2.5-72B, pad intermediate_size with the script below; otherwise AWQ with group size 128 cannot be applied during quantization:
# requires transformers==4.46.3 (for shard_checkpoint)
import json
import os
from collections import OrderedDict
from typing import Dict

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)

def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
    os.makedirs(output_dir, exist_ok=True)  # make sure the output directory exists
    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
            with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                for key in f.keys():
                    qwen_state_dict[key] = f.get_tensor(key)

    qwen2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    torch_dtype = None
    for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
        if torch_dtype is None:
            torch_dtype = value.dtype
        shape_list = [int(i) for i in value.shape]
        if len(shape_list) == 2:
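            # Qwen2.5-72B's intermediate_size is 29568; pad that dimension to 29696
            # so AWQ with group size 128 can be applied (see step 1 above).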
            if shape_list[0] == 29568:
                value = torch.concat((value, torch.zeros([128, shape_list[1]], dtype=value.dtype)), dim=0)
            if shape_list[1] == 29568:
                value = torch.cat((value, torch.zeros([shape_list[0], 128], dtype=value.dtype)), dim=1)
        qwen2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(qwen2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))

    return str(torch_dtype).replace("torch.", "")

if __name__ == "__main__":
    save_weight(input_dir="/path/to/models/Qwen2.5-72B-Instruct", 
                output_dir="/path/to/models/Qwen2.5-72B-Instruct-Padding", 
                shard_size="4GB", save_safetensors=True)
  2. Then use TensorRT Model Optimizer to produce a W4A8-quantized model; follow this script and change the quantization config to W4A8_AWQ_BETA_CFG:
    Model Optimizer PTQ sample
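
As a reference, here is a minimal sketch of this PTQ step (my own, not the official script); it assumes modelopt exposes mtq.quantize, mtq.W4A8_AWQ_BETA_CFG, and export_hf_checkpoint as in the linked example, and the paths and calibration prompts are placeholders:

# Hedged sketch only: assumes modelopt's PTQ API (mtq.quantize, W4A8_AWQ_BETA_CFG,
# export_hf_checkpoint); paths and calibration prompts are placeholders.
import torch
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "/path/to/models/Qwen2.5-72B-Instruct-Padding"
export_dir = "/path/to/models/Qwen2.5-72B-Instruct-W4A8-test"

model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_dir)

def calib_loop(m):
    # Tiny placeholder calibration loop; use a real calibration dataset in practice.
    for prompt in ["Hello, how are you?", "The capital of France is"]:
        inputs = tokenizer(prompt, return_tensors="pt").to(m.device)
        with torch.no_grad():
            m(**inputs)

mtq.quantize(model, mtq.W4A8_AWQ_BETA_CFG, forward_loop=calib_loop)
export_hf_checkpoint(model, export_dir=export_dir)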

  3. Run with trtllm-serve:

trtllm-serve serve /path/to/models/Qwen2.5-72B-Instruct-W4A8-test/ --tp_size 2 --extra_llm_api_options /path/to/cfg/qwen2.5-72b-w4a8.yaml

The LLM API config details:

trust_remote_code: true
enable_attention_dp: false
cuda_graph_config:
  enable_padding: true
  max_batch_size: 128
  enable_iter_perf_stats: true
kv_cache_config:
  dtype: fp8
  enable_partial_reuse: false
  free_gpu_memory_fraction: 0.92
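
The same failure can also be reproduced directly through the LLM API; a minimal sketch of mine, assuming the tensorrt_llm.LLM constructor and the llmapi KvCacheConfig/CudaGraphConfig classes that the YAML keys above map to:

# Hedged repro sketch: assumes KvCacheConfig/CudaGraphConfig mirror the YAML fields above.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

llm = LLM(
    model="/path/to/models/Qwen2.5-72B-Instruct-W4A8-test/",
    tensor_parallel_size=2,
    trust_remote_code=True,
    cuda_graph_config=CudaGraphConfig(enable_padding=True, max_batch_size=128),
    kv_cache_config=KvCacheConfig(
        dtype="fp8",  # switching this to "auto" (FP16/BF16 KV cache) avoids the error
        enable_partial_reuse=False,
        free_gpu_memory_fraction=0.92,
    ),
)
print(llm.generate(["Hello"]))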

Expected behavior

TensorRT-LLM should allow the W4A8 + FP8 KV cache combination.

Actual behavior

The engine throws an "FMHA kernels are not found with these parameters" error during initialization:

Error logs:
RuntimeError: [TensorRT-LLM][ERROR] Assertion failed: FMHA kernels are not found with these parameters:

S : 0

D : 128

DV : 128

AttentionMaskType : 1

AttentionInputLayout : 2

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x0000000401000d4d

Available kernel functions:

Meta[95]:

S : 0

D : 128

DV : 128

AttentionMaskType : 0

AttentionInputLayout : 0

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x000c28a40100014d

Meta[93]:

S : 0

D : 192

DV : 192

AttentionMaskType : 3

AttentionInputLayout : 2

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x0000000601801d4d

Meta[92]:

S : 0

D : 192

DV : 192

AttentionMaskType : 2

AttentionInputLayout : 2

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x000000060180154d

Meta[94]:

S : 0

D : 80

DV : 80

AttentionMaskType : 0

AttentionInputLayout : 0

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x000c28a280a0014d

Meta[91]:

S : 0

D : 192

DV : 192

AttentionMaskType : 1

AttentionInputLayout : 2

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x0000000601800d4d

Meta[90]:

S : 0

D : 192

DV : 192

AttentionMaskType : 3

AttentionInputLayout : 0

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x000000060180194d

Meta[89]:

S : 0

D : 192

DV : 192

AttentionMaskType : 2

AttentionInputLayout : 0

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x000000060180114d

Meta[88]:

S : 0

D : 192

DV : 192

AttentionMaskType : 1

AttentionInputLayout : 0

AttnLogitSoftcapping : 0

AlibiSupported : 1

WarpSpecialization : 0

Tiled : 0

FP32Accumulation : 1

FlashAttention : 1

Interleaved : 0

ReturnSoftmaxStats : 1

Unroll : 1

Hash : 0x000000060180094d

(/src/tensorrt_llm/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.cpp:277)

1 0x7fdfe5824f06 tensorrt_llm::common::throwRuntimeError(char const*, int, char const*) + 97

2 0x7fdfbff4fe41 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x1b4fe41) [0x7fdfbff4fe41]

3 0x7fdfc01f24e4 int tensorrt_llm::common::op::AttentionOp::enqueueContext<__nv_bfloat16, tensorrt_llm::kernels::KVBlockArray>(tensorrt_llm::common::op::AttentionOp::EnqueueContextParams<__nv_bfloat16> const&, CUstream_st*) + 10500

4 0x7fdfe5940d3f torch_ext::trtllm::attention::Runner<__nv_bfloat16, __nv_bfloat16>::run(tensorrt_llm::common::op::AttentionOp&, bool, int, int, int, int, int, at::Tensor, at::Tensor, std::optional<at::Tensor>, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, at::Tensor, int, at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::vector<std::optional<at::Tensor>, std::allocator<std::optional<at::Tensor> > >, std::optional<at::Tensor>, c10::ArrayRef<std::optional<at::Tensor> >, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>) const + 5055

5 0x7fdfe592c712 torch_ext::attention(at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor&, std::optional<at::Tensor>, std::optional<c10::ScalarType>, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, bool, bool, long, long, long, long, long, std::optional<long>, long, long, long, long, long, long, long, double, long, long, double, long, std::vector<double, std::allocator<double> >, std::vector<long, std::allocator<long> >, bool, std::optional<long>, bool, std::optional<long>, std::optional<long>, std::optional<long>, std::optional<long>, std::optional<long>, std::optional<long>, std::optional<at::Tensor>, std::optional<at::Tensor>, std::vector<std::optional<at::Tensor>, std::allocator<std::optional<at::Tensor> > >, std::optional<long>, std::optional<at::Tensor>, std::vector<bool, std::allocator<bool> >, std::vector<std::optional<at::Tensor>, std::allocator<std::optional<at::Tensor> > >, std::vector<std::optional<at::Tensor>, std::allocator<std::optional<at::Tensor> > >) + 8594

6 0x7fdfea157c71 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/bindings.cpython-312-x86_64-linux-gnu.so(+0x157c71) [0x7fdfea157c71]

7 0x7fdfea1bc9f1 /usr/local/lib/python3.12/dist-packages/tensorrt_llm/bindings.cpython-312-x86_64-linux-gnu.so(+0x1bc9f1) [0x7fdfea1bc9f1]

8 0x5db55b _PyEval_EvalFrameDefault + 19483

9 0x54cd94 /usr/bin/python() [0x54cd94]

10 0x54b3b5 PyObject_Call + 277

11 0x5db55b _PyEval_EvalFrameDefault + 19483

12 0x54cd94 /usr/bin/python() [0x54cd94]

13 0x54b3b5 PyObject_Call + 277

14 0x5db55b _PyEval_EvalFrameDefault + 19483

15 0x54aa9a _PyObject_Call_Prepend + 394

16 0x5a3628 /usr/bin/python() [0x5a3628]

17 0x54b30c PyObject_Call + 108

18 0x5db55b _PyEval_EvalFrameDefault + 19483

19 0x54cd94 /usr/bin/python() [0x54cd94]

20 0x54b3b5 PyObject_Call + 277

21 0x5db55b _PyEval_EvalFrameDefault + 19483

22 0x54cd94 /usr/bin/python() [0x54cd94]

23 0x54b3b5 PyObject_Call + 277

24 0x5db55b _PyEval_EvalFrameDefault + 19483

25 0x54aa9a _PyObject_Call_Prepend + 394

26 0x5a3628 /usr/bin/python() [0x5a3628]

27 0x54924e _PyObject_MakeTpCall + 318

28 0x5d73c9 _PyEval_EvalFrameDefault + 2697

29 0x54cd94 /usr/bin/python() [0x54cd94]

30 0x54b3b5 PyObject_Call + 277

31 0x5db55b _PyEval_EvalFrameDefault + 19483

32 0x54cd94 /usr/bin/python() [0x54cd94]

33 0x54b3b5 PyObject_Call + 277

34 0x5db55b _PyEval_EvalFrameDefault + 19483

35 0x54aa9a _PyObject_Call_Prepend + 394

36 0x5a3628 /usr/bin/python() [0x5a3628]

37 0x54924e _PyObject_MakeTpCall + 318

38 0x5d73c9 _PyEval_EvalFrameDefault + 2697

39 0x54cd94 /usr/bin/python() [0x54cd94]

40 0x54b3b5 PyObject_Call + 277

41 0x5db55b _PyEval_EvalFrameDefault + 19483

42 0x54cd94 /usr/bin/python() [0x54cd94]

43 0x54b3b5 PyObject_Call + 277

44 0x5db55b _PyEval_EvalFrameDefault + 19483

45 0x54aa9a _PyObject_Call_Prepend + 394

46 0x59e09f /usr/bin/python() [0x59e09f]

47 0x599b63 /usr/bin/python() [0x599b63]

48 0x54b30c PyObject_Call + 108

49 0x5db55b _PyEval_EvalFrameDefault + 19483

50 0x54aa9a _PyObject_Call_Prepend + 394

51 0x59e09f /usr/bin/python() [0x59e09f]

52 0x599b63 /usr/bin/python() [0x599b63]

53 0x54924e _PyObject_MakeTpCall + 318

54 0x5d73c9 _PyEval_EvalFrameDefault + 2697

55 0x5d58eb PyEval_EvalCode + 347

56 0x5d347c /usr/bin/python() [0x5d347c]

57 0x581f0d /usr/bin/python() [0x581f0d]

58 0x549b85 PyObject_Vectorcall + 53

59 0x5d73c9 _PyEval_EvalFrameDefault + 2697

60 0x6bcce2 /usr/bin/python() [0x6bcce2]

61 0x6bc912 Py_RunMain + 562

62 0x6bc57d Py_BytesMain + 45

63 0x7fe17722a1ca /usr/lib/x86_64-linux-gnu/libc.so.6(+0x2a1ca) [0x7fe17722a1ca]

64 0x7fe17722a28b __libc_start_main + 139

65 0x657ce5 _start + 37

The above exception was the direct cause of the following exception:

Traceback (most recent call last):

File "/usr/local/bin/trtllm-serve", line 8, in <module>

sys.exit(main())

^^^^^^

File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1442, in __call__

return self.main(*args, **kwargs)

^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1363, in main

rv = self.invoke(ctx)

^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1830, in invoke

return _process_result(sub_ctx.command.invoke(sub_ctx))

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1226, in invoke

return ctx.invoke(self.callback, **ctx.params)

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 794, in invoke

return callback(*args, **kwargs)

^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/commands/serve.py", line 374, in serve

launch_server(host, port, llm_args, metadata_server_cfg, server_role)

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/commands/serve.py", line 148, in launch_server

llm = PyTorchLLM(**llm_args)

^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/llm.py", line 1098, in __init__

super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/llm.py", line 986, in __init__

super().__init__(model,

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/llm.py", line 228, in __init__

self._build_model()

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/llmapi/llm.py", line 1045, in _build_model

self._executor = self._executor_cls.create(

^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/executor/executor.py", line 495, in create

return GenerationExecutor._create_ipc_executor(

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/executor/executor.py", line 414, in _create_ipc_executor

return GenerationExecutorProxy(

^^^^^^^^^^^^^^^^^^^^^^^^

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/executor/proxy.py", line 104, in __init__

self._start_executor_workers(worker_kwargs)

File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/executor/proxy.py", line 345, in _start_executor_workers

raise RuntimeError(

RuntimeError: Executor worker returned error

Additional notes

I'm trying W4A8 AWQ quantization of Qwen2.5-72B-Instruct and serving it with trtllm-serve on NVIDIA L20, TP=2.

When using the W4A8 AWQ model with an FP16 KV cache (precision auto), everything works well.

But when I switch to an FP8 KV cache, with either static or dynamic quantization, engine initialization throws the "FMHA kernels are not found with these parameters" error. I wonder whether some kernel variants are missing and cause this error.
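
For clarity, the only difference between the working and the failing setup is the KV cache dtype in the YAML above:

# works (FP16/BF16 KV cache, i.e. precision auto)
kv_cache_config:
  dtype: auto

# fails during init with "FMHA kernels are not found with these parameters"
kv_cache_config:
  dtype: fp8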

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.
