Commit c39d22d

Revert "[https://nvbugs/5567586][feat] Ampere xqa swa specdec for GPT-OSS Eagle3-one-model (NVIDIA#8383)"

This reverts commit 0a09465.

Signed-off-by: Jhao-Ting Chen <[email protected]>

1 parent c85a4e0 commit c39d22d

File tree: 10 files changed, +71 -225 lines changed
cpp/kernels/xqa/mha.cu
Lines changed: 12 additions & 64 deletions

@@ -466,53 +466,20 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
 #define MMAS_N_PER_MASK 2
 
 __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
-    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-    ,
-    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
-#endif
-)
+    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
 {
     uint32_t const idxInQuad = laneId() % 4;
     uint32_t const idxQuad = laneId() / 4;
     // Packed mask is aligned with 32 bits (2 uint16_t).
     uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
     uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
-    constexpr uint64_t fullMask = ~uint64_t{0};
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-    Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
-    Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
-    bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
-    assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange));
-    int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
-    uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
-    bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
-#else
-    constexpr bool ctaNeedBegMask = false;
-    bool const ctaNeedSpecDecMask = true;
-    int32_t const tok0NbMaskOut = -2147483648;
-#endif
-    bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
-
-    if (!needMask)
-    {
-        return;
-    }
 #pragma unroll
     for (uint32_t m = 0; m < acc.rows; m++)
     {
 #pragma unroll
         for (uint32_t i = 0; i < InstAcc::rows; i++)
         {
-            uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8) / headGrpSize;
-            uint32_t const tokenRow = min(idxQTokInCta, actualQSeqLen - 1);
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-            int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta);
-            uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask);
-#else
-            uint64_t const begMask = fullMask;
-#endif
-
+            uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
 #pragma unroll
             for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
             {
@@ -524,15 +491,12 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
                 uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
                     ? 0u
                     : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t packedMask = 0u;
                 uint32_t const maskPosStart = (maskPos0 / 16) * 16;
-                uint32_t packedMask = ~uint32_t{0};
-                if (ctaNeedSpecDecMask)
-                {
-                    reinterpret_cast<uint16_t*>(&packedMask)[0]
-                        = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
-                    reinterpret_cast<uint16_t*>(&packedMask)[1]
-                        = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
-                }
+                reinterpret_cast<uint16_t*>(&packedMask)[0]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
+                reinterpret_cast<uint16_t*>(&packedMask)[1]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
 #pragma unroll
                 for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
                 {
@@ -546,11 +510,7 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
                         bool const maskFlag = col + actualQSeqLen < nbValidCols
                             ? true
                             : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
-
-                        bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true;
-
-                        acc(m, n)(i, j)
-                            = maskFlag && begMaskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
+                        acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
                     }
                 }
             }
@@ -1651,14 +1611,8 @@ CUBIN_EXPORT __global__
 #endif
 
     uint32_t const cacheSeqLen = getCacheSeqLen<usePagedKVCache>(cacheList, idxReq);
-#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
-    uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
-    int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
-    uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
-
-#elif SLIDING_WINDOW
+#if SLIDING_WINDOW
    bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
-    assert(!SPEC_DEC || !rtIsReallySliding);
    uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
 #else
    constexpr bool rtIsReallySliding = false;
@@ -1672,9 +1626,7 @@ CUBIN_EXPORT __global__
 #endif
 
    uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
-#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
-    uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
-#elif SPEC_DEC
+#if SPEC_DEC
    uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #endif
 
@@ -1960,12 +1912,8 @@ CUBIN_EXPORT __global__
        if (seqIter >= nbSeqItersWithoutMask)
        {
            uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
-            applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-                ,
-                tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
-#endif
-            );
+            applyMaskFromInput(
+                warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
        }
 #else
        bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
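
Note on the change above: the deleted path computed an extra per-row "begin mask" so that, under SLIDING_WINDOW with linear (non-tree) spec-dec, key columns falling before each query token's sliding window were forced to safeInitRowMax. Below is a minimal Python re-expression of that predicate, written only to make the revert easier to trace; names mirror the CUDA diff, keeps_column() is a hypothetical helper, and the real kernel works on 64-bit lane masks rather than per-column calls.

# Sketch, not kernel code: the begin-mask predicate removed by this revert.
FULL_MASK = (1 << 64) - 1  # counterpart of `constexpr uint64_t fullMask = ~uint64_t{0}`

def keeps_column(col: int, idx_q_tok_in_cta: int, tok0_win_beg: int,
                 warp_tile_token_beg: int) -> bool:
    # tok0NbMaskOut: number of leading columns of this warp tile that fall
    # before the sliding window of query token 0.
    tok0_nb_mask_out = tok0_win_beg - warp_tile_token_beg
    # Each later query token shifts the window start right by one key token.
    beg_nb_mask_out = tok0_nb_mask_out + idx_q_tok_in_cta
    # `begMask = begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask`:
    # clear the low begNbMaskOut bits, one bit per key column in the tile.
    beg_mask = (FULL_MASK << beg_nb_mask_out) & FULL_MASK if beg_nb_mask_out > 0 else FULL_MASK
    return bool(beg_mask & (1 << col))

After the revert, only the spec-dec packed mask (maskFlag) and the col < nbValidCols bound gate the accumulator, so the tok0WinBeg/seqIter/cacheSeqLen/warpTileTokenBeg parameters disappear from applyMaskFromInput accordingly.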

jenkins/L0_Test.groovy
Lines changed: 0 additions & 1 deletion

@@ -2895,7 +2895,6 @@ def launchTestJobs(pipeline, testFilter)
 
     x86SlurmTestConfigs = [
         "DGX_H100-2_GPUs-PyTorch-Others-1": ["dgx-h100-x2-oci", "l0_dgx_h100", 1, 1, 2],
-        "DGX_H100-2_GPUs-PyTorch-GptOss-1": ["dgx-h100-x2-oci", "l0_dgx_h100", 1, 1, 2],
         "DGX_H100-2_GPUs-PyTorch-Ray-1": ["dgx-h100-x2-oci", "l0_dgx_h100", 1, 1, 2],
         "DGX_H100-4_GPUs-PyTorch-DeepSeek-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
         "DGX_H100-4_GPUs-PyTorch-GptOss-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],

tensorrt_llm/_torch/attention_backend/trtllm.py
Lines changed: 4 additions & 7 deletions

@@ -475,7 +475,7 @@ def run(
             self.spec_decoding_generation_lengths,
             self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
         ]
-        if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
+        if get_sm_version() >= 100:
             spec_decoding_tensor_params.append(
                 self.spec_decoding_bl_tree_mask_offset)
             spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
@@ -1219,12 +1219,12 @@ def update_spec_dec_param(
 
         # spec_dec mode should only be enabled for non-sm100 machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
-            not self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()))
+            get_sm_version() < 100 or get_sm_version() == 120)
 
         self.is_spec_dec_tree = spec_tree_manager is not None
         self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree
 
-        if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
+        if get_sm_version() >= 100 and get_sm_version() != 120:
             if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
                 assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
 
@@ -1260,7 +1260,7 @@ def update_spec_dec_param(
                 device='cuda',
             )
 
-        if self.is_sm_version_trtllm_gen_kernel(sm=get_sm_version()):
+        if get_sm_version() >= 100:
             self.spec_decoding_param_prepare_for_blackwell()
         else:
             self.spec_decoding_bl_tree_mask_offset = None
@@ -1371,9 +1371,6 @@ def generate_spec_decoding_generation_length(self, max_draft_len):
         self.spec_decoding_generation_lengths[:self.max_num_requests].copy_(
             spec_decoding_generation_length, non_blocking=True)
 
-    def is_sm_version_trtllm_gen_kernel(self, sm):
-        return not (sm < 100 or sm in [120, 121])
-
 
 class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
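
Note on the change above: the revert swaps the is_sm_version_trtllm_gen_kernel helper back to inline SM checks. Both predicates appear verbatim in the diff; the sketch below just puts them side by side (the function names and comparison loop are illustrative, not library code) to show the one behavioral nuance: the helper also excluded SM 121, while the restored inline form tests != 120 only.

# Sketch comparing the two gates; the predicates are copied from the diff above.
def gate_with_helper(sm: int) -> bool:
    # Removed helper: trtllm-gen kernel path on SM >= 100, except SM 120/121.
    return not (sm < 100 or sm in [120, 121])

def gate_inline(sm: int) -> bool:
    # Restored inline form, e.g. `get_sm_version() >= 100 and get_sm_version() != 120`.
    return sm >= 100 and sm != 120

for sm in (90, 100, 120, 121):
    print(f"SM {sm}: helper={gate_with_helper(sm)}, inline={gate_inline(sm)}")
# Only SM 121 diverges: excluded by the helper, admitted by the inline check.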

tests/integration/defs/accuracy/test_llm_api_pytorch.py
Lines changed: 3 additions & 85 deletions

@@ -4248,16 +4248,14 @@ def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
         ["CUTLASS",
          pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
         ids=["cutlass", "trtllm", "triton"])
-    def test_eagle3_4gpus(self, moe_backend, one_model, overlap_scheduler,
-                          mocker):
+    def test_eagle3(self, moe_backend, one_model, overlap_scheduler, mocker):
         if moe_backend == "TRITON":
             if not IS_TRITON_KERNELS_AVAILABLE:
                 pytest.skip("Triton kernels are not available")
 
-        if get_sm_version() == 90:
+        if get_sm_version() == 90 and moe_backend == "CUTLASS":
             pytest.skip(
-                "https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue for only TP=4"
-            )
+                "https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue")
 
         MAX_OUTPUT_LEN = 128179
         MAX_INPUT_LEN = 32768
@@ -4320,86 +4318,6 @@
                           sampling_params=sampling_params,
                           extra_evaluator_kwargs=extra_evaluator_kwargs)
 
-    @pytest.mark.skip_less_device(2)
-    @pytest.mark.timeout(14400)
-    @pytest.mark.parametrize("overlap_scheduler", [True, False],
-                             ids=["overlap_scheduler", "no_overlap_scheduler"])
-    @pytest.mark.parametrize("one_model", [True, False],
-                             ids=["one_model", "two_model"])
-    @pytest.mark.parametrize(
-        "moe_backend",
-        ["CUTLASS",
-         pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
-        ids=["cutlass", "trtllm", "triton"])
-    def test_eagle3_2gpus(self, moe_backend, one_model, overlap_scheduler,
-                          mocker):
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
-
-        MAX_OUTPUT_LEN = 128179
-        MAX_INPUT_LEN = 32768
-
-        mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
-        mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
-                          {"scores_filter": "exact_match,flexible-extract"})
-
-        mocker.patch.object(GPQADiamond, "MAX_OUTPUT_LEN", MAX_OUTPUT_LEN)
-        mocker.patch.object(GPQADiamond, "MAX_INPUT_LEN", MAX_INPUT_LEN)
-
-        # https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
-        pytorch_config = dict(
-            max_batch_size=8,
-            disable_overlap_scheduler=not overlap_scheduler,
-            cuda_graph_config=CudaGraphConfig(max_batch_size=8))
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
-                                        dtype="auto")
-
-        eagle_model_dir = f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3"
-        draft_len = 3
-        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
-                                          speculative_model_dir=eagle_model_dir,
-                                          eagle3_one_model=one_model)
-
-        max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
-        llm = LLM(self.MODEL_PATH,
-                  tensor_parallel_size=2,
-                  pipeline_parallel_size=1,
-                  moe_expert_parallel_size=1,
-                  kv_cache_config=kv_cache_config,
-                  max_seq_len=max_seq_len,
-                  speculative_config=spec_config,
-                  **pytorch_config,
-                  enable_attention_dp=False,
-                  moe_config=MoeConfig(backend=moe_backend))
-
-        with llm:
-            model_name = "GPT-OSS/120B-MXFP4"
-
-            # GSM8K
-            task = GSM8K(model_name)
-            task.evaluate(llm,
-                          extra_evaluator_kwargs=self.extra_evaluator_kwargs)
-
-            # GPQA Medium Reasoning
-            task = GPQADiamond(model_name)
-
-            chat_template_kwargs = dict(reasoning_effort="medium")
-            extra_evaluator_kwargs = {
-                **self.extra_evaluator_kwargs, "chat_template_kwargs":
-                chat_template_kwargs
-            }
-
-            sampling_params = SamplingParams(
-                temperature=1.0,
-                top_p=1.0,
-                max_tokens=MAX_OUTPUT_LEN,
-                truncate_prompt_tokens=MAX_INPUT_LEN)
-
-            task.evaluate(llm,
-                          sampling_params=sampling_params,
-                          extra_evaluator_kwargs=extra_evaluator_kwargs)
-
     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["GB200"])
     @pytest.mark.parametrize(

tests/integration/test_lists/qa/llm_function_core.txt
Lines changed: 12 additions & 12 deletions

@@ -566,18 +566,18 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-au
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-no_overlap_scheduler]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram

tests/integration/test_lists/qa/llm_function_core_sanity.txt
Lines changed: 12 additions & 12 deletions

@@ -103,18 +103,18 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-au
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-two_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-one_model-no_overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-overlap_scheduler]
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[triton-two_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-one_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass-two_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-one_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[trtllm-two_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-one_model-no_overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-overlap_scheduler]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[triton-two_model-no_overlap_scheduler]
 accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
