diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index bbbc842dade..317bc8bbee9 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -8,7 +8,8 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_utils import \ MODEL_CLASS_VISION_ENCODER_MAPPING -from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str +from tensorrt_llm._utils import (confidential_compute_enabled, + str_dtype_to_binding, torch_dtype_to_str) from tensorrt_llm.bindings.executor import DecodingMode from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig, EagleDecodingConfig, KvCacheConfig, @@ -828,7 +829,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_batch_size: int, speculative_config: SpeculativeConfig, max_beam_width: int, - disable_flash_infer_sampling: bool): + disable_flash_infer_sampling: bool, + enable_async_worker: bool): max_num_sequences = max_batch_size * mapping.pp_size max_draft_len = (0 if speculative_config is None else speculative_config.max_draft_len) @@ -842,6 +844,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_num_sequences=max_num_sequences, max_beam_width=max_beam_width, disable_flash_infer_sampling=disable_flash_infer_sampling, + enable_async_worker=enable_async_worker, ) @@ -859,6 +862,9 @@ def instantiate_sampler( kv_cache_config: KvCacheConfig, disable_flash_infer_sampling: bool, ): + enable_async_worker = (confidential_compute_enabled() + or llm_args.enable_sampler_async_worker) + sampler_args = create_torch_sampler_args( mapping, max_seq_len=engine.max_seq_len, @@ -866,6 +872,7 @@ def instantiate_sampler( speculative_config=speculative_config, max_beam_width=max_beam_width, disable_flash_infer_sampling=disable_flash_infer_sampling, + enable_async_worker=enable_async_worker, ) decoding_mode = get_decoding_mode(decoding_config=decoding_config, max_beam_width=max_beam_width) @@ -892,7 +899,8 @@ def instantiate_sampler( max_batch_size=max_batch_size, max_beam_width=max_beam_width, decoding_config=decoding_config, - kv_cache_config=kv_cache_config) + kv_cache_config=kv_cache_config, + enable_async_worker=enable_async_worker) if not engine.model.model_config.is_generation: # NOTE: choose sampler based on model type return EarlyStopSampler() diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index b990310b984..3e4ed9bc2c9 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -52,7 +52,8 @@ LlmResponse, get_draft_token_length) from .model_engine import ModelEngine from .resource_manager import ResourceManager -from .sampler import Sampler, SampleState, SampleStateTensors +from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState, + SampleStateTensors) from .scheduler import RequestScheduler, ScheduledRequests # Environment variable to specify iteration ranges for profiling start/stop. 
@@ -359,6 +360,10 @@ def start_worker(self): target=self._event_loop_wrapper, daemon=True) self.worker_thread.start() self.worker_started = True + # Start the sampler's async worker, if it is enabled + if (isinstance(self.sampler, AsyncWorkerMixin) + and self.sampler.async_worker_enabled()): + self.sampler.async_worker_start() def _set_global_steady_clock_offset(self): assert self.global_rank >= 0, "rank should be >= 0" @@ -451,6 +456,10 @@ def shutdown(self): keys = list(self.virtual_memory_pools.keys()) for key in keys: del self.virtual_memory_pools[key] + # Stop the sampler's async worker, if it was used + if (isinstance(self.sampler, AsyncWorkerMixin) + and self.sampler.async_worker_enabled()): + self.sampler.async_worker_stop() def can_enqueue_requests(self) -> bool: """ @@ -1600,7 +1609,7 @@ def _forward_step_inter_pp(self, scheduled_batch) -> SampleState: self._update_request_states(scheduled_batch) return self.sampler.SampleState( scheduled_requests=scheduled_batch, - sampler_event=sampler_event, + sampler_event=SamplerEvent(cuda_event=sampler_event), ) def _validate_request(self, request: LlmRequest): diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 99ad3453b7e..15d0c54ddf7 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -17,6 +17,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterable +from concurrent import futures from dataclasses import dataclass from functools import cached_property from itertools import repeat @@ -92,6 +93,17 @@ def values(self): return vars(self).values() +@dataclass(kw_only=True) +class SamplerEvent: + cuda_event: torch.cuda.Event + worker_futures: Optional[list[futures.Future[Any]]] = None + + def synchronize(self): + if self.worker_futures: + futures.wait(self.worker_futures) + self.cuda_event.synchronize() + + @dataclass(kw_only=True) class SampleState: scheduled_requests: ScheduledRequests @@ -99,7 +111,7 @@ class SampleState: device: Optional[SampleStateTensors] = None host: Optional[SampleStateTensors] = None - sampler_event: Optional[torch.cuda.Event] = None + sampler_event: Optional[SamplerEvent] = None class Sampler(ABC): @@ -592,7 +604,106 @@ class SampleStateTorch(SampleState): host: SampleStateTensorsHostTorch -class TorchSampler(Sampler): +class AsyncWorkerMixin: + """ + Mixin that adds the ability to fork off operations to run on a worker + thread (particularly D2H copies). 
If the async worker isn't active,
+    operations will seamlessly run on the main thread.
+    """
+
+    MAX_WORKERS = 1
+
+    def _async_worker_active(self) -> bool:
+        return hasattr(self, "_async_worker") and self._async_worker is not None
+
+    def _async_worker_init(self, enable_async_worker: bool):
+        self._enable_async_worker = enable_async_worker
+        self._async_worker = None
+        self._async_worker_futures: list[futures.Future[Any]] = []
+
+    def async_worker_enabled(self) -> bool:
+        return hasattr(self, "_enable_async_worker") and self._enable_async_worker
+
+    def async_worker_start(self):
+        assert self.async_worker_enabled()
+        if not self._async_worker_active():
+
+            def _async_worker_initializer(device_id):
+                # The current device is set per thread, so we need to set it
+                # again here
+                torch.cuda.set_device(device_id)
+                # Submit the host copies in a separate stream to prevent the
+                # blocking copies from gating subsequent async work
+                torch.cuda.set_stream(torch.cuda.Stream())
+
+            self._async_worker = futures.ThreadPoolExecutor(
+                max_workers=self.MAX_WORKERS,
+                initializer=_async_worker_initializer,
+                initargs=(torch.cuda.current_device(),),
+            )
+
+    def async_worker_stop(self):
+        assert self.async_worker_enabled()
+        if self._async_worker_active():
+            self._async_worker.shutdown(wait=True)
+            self._async_worker = None
+
+    @torch.inference_mode()
+    def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
+        # Make sure the async work takes place after all prior operations on
+        # the primary stream. synchronize() is intentionally chosen instead of
+        # wait() here; otherwise, blocking copies will stall subsequent CUDA
+        # API calls on the main thread
+        ready.synchronize()
+
+        # Do the work
+        result = func(*args, **kwargs)
+
+        # Work submitted to the async worker is expected to block at the end,
+        # consistent with the semantics of futures; make sure that we wait for
+        # everything to complete
+        torch.cuda.current_stream().synchronize()
+
+        return result
+
+    def _async_worker_submit(self, func, /, *args, **kwargs):
+        if self._async_worker_active():
+            # Record an event on the main thread/stream that we will
+            # synchronize with on the worker thread/stream
+            ready = torch.cuda.Event()
+            ready.record()
+
+            # Submit the async work
+            result = self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
+
+            # Save the future, so that we can await it later
+            self._async_worker_futures.append(result)
+
+            return result
+        else:
+            # If the async worker is not in use, just execute the function
+            return func(*args, **kwargs)
+
+    def _copy_to_host(self, src: torch.Tensor) -> torch.Tensor:
+        dest = torch.empty_like(src, device="cpu", pin_memory=True)
+        self._async_worker_submit(dest.copy_, src, non_blocking=True)
+        return dest
+
+    def _sampler_event_get(self) -> SamplerEvent:
+        cuda_event = torch.cuda.Event()
+        cuda_event.record()
+
+        # Transfer ownership to worker_futures and re-initialize
+        if self._async_worker_active():
+            worker_futures = self._async_worker_futures
+            self._async_worker_futures = []
+        else:
+            worker_futures = None
+
+        return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
+
+
+class TorchSampler(Sampler, AsyncWorkerMixin):
     SampleState = SampleStateTorch
 
     @override
@@ -616,6 +727,7 @@ class Args:
         max_beam_width: int
         max_total_draft_tokens: int
         disable_flash_infer_sampling: bool = False
+        enable_async_worker: bool = False
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -662,6 +774,8 @@ def __init__(self, args: Args):
self._global_seed = 42 self._generator = None + self._async_worker_init(args.enable_async_worker) + def get_generator(self, device: torch.device) -> torch.Generator: """Get a deterministic generator for the specified device. @@ -1099,10 +1213,9 @@ def sample_async( self._write_finish_reasons( requests, finish_reasons=finish_reasons, seq_slots=seq_slots, new_tokens=new_tokens ) - finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True) + finish_reasons_host = self._copy_to_host(finish_reasons) - sampler_event = torch.cuda.Event() - sampler_event.record() + sampler_event = self._sampler_event_get() return SampleStateTorch( scheduled_requests=scheduled_requests, device=SampleStateTensors(new_tokens=new_tokens), @@ -1362,7 +1475,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool: new_tokens_cuda.view(-1, *new_tokens_cuda.shape[2:])[:, beam, ...].scatter_( 0, batch_dest_indices_1d_cuda, batch_next_tokens_cuda_int ) - new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True) + new_tokens_host = self._copy_to_host(new_tokens_cuda) return new_tokens_host @@ -1709,10 +1822,8 @@ def _process_requests( logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 ) # Use a single D2H copy to reduce overheads - topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) - topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True) - topk_vals.copy_(topk_vals_cuda, non_blocking=True) - topk_indices.copy_(topk_indices_cuda, non_blocking=True) + topk_vals = self._copy_to_host(topk_vals_cuda) + topk_indices = self._copy_to_host(topk_indices_cuda) current_offset = 0 for req_id, steps in zip( logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist() @@ -1791,7 +1902,7 @@ class SampleStateTRTLLM(SampleState): host: Optional[SampleStateTensorsHostTRTLLM] = None -class TRTLLMSampler(Sampler): +class TRTLLMSampler(Sampler, AsyncWorkerMixin): MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding SampleState = SampleStateTRTLLM @@ -1811,6 +1922,7 @@ def __init__( max_beam_width: int, decoding_config: Optional[DecodingConfig] = None, kv_cache_config: Optional[KvCacheConfig] = None, + enable_async_worker: bool = False, ): vocab_size = model.config.vocab_size num_hidden_layers = model.config.num_hidden_layers @@ -1861,6 +1973,8 @@ def __init__( self._initialize_store() self._instantiate_algorithms() + self._async_worker_init(enable_async_worker) + def _initialize_store(self): torch_stream = torch.cuda.current_stream().cuda_stream cuda_stream = CudaStream(torch_stream) @@ -2018,17 +2132,17 @@ def sample_async( finalize_events[request.request_id] = self._finalize_request(request, False) elif request.streaming: finalize_events[request.request_id] = self._finalize_request(request, True) - gathered_ids = self.store["decoder_state"].gathered_ids.to("cpu", non_blocking=True) - new_output_tokens = self.store["decoder_state"].all_new_tokens.to("cpu", non_blocking=True) - finished_sum = self.store["decoder_state"].finished_sum.to("cpu", non_blocking=True) - finish_reasons = self.store["decoder_state"].finish_reasons.to("cpu", non_blocking=True) - sequence_lengths = self.store["decoder_state"].sequence_lengths.to("cpu", non_blocking=True) + gathered_ids = self._copy_to_host(self.store["decoder_state"].gathered_ids) + new_output_tokens = self._copy_to_host(self.store["decoder_state"].all_new_tokens) + finished_sum = self._copy_to_host(self.store["decoder_state"].finished_sum) + finish_reasons = 
self._copy_to_host(self.store["decoder_state"].finish_reasons)
+        sequence_lengths = self._copy_to_host(self.store["decoder_state"].sequence_lengths)
 
         log_probs = None
         cum_log_probs = None
         if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
-            log_probs = self.store["decoder_state"].log_probs.to("cpu", non_blocking=True)
-            cum_log_probs = self.store["decoder_state"].cum_log_probs.to("cpu", non_blocking=True)
+            log_probs = self._copy_to_host(self.store["decoder_state"].log_probs)
+            cum_log_probs = self._copy_to_host(self.store["decoder_state"].cum_log_probs)
 
         device = SampleStateTensors(new_tokens=self.store["decoder_state"].all_new_tokens)
 
@@ -2042,8 +2156,7 @@ def sample_async(
             gathered_ids=gathered_ids,
         )
 
-        sampler_event = torch.cuda.Event()
-        sampler_event.record()
+        sampler_event = self._sampler_event_get()
 
         self.micro_batch_idx = (self.micro_batch_idx + 1) % self.num_micro_batches
 
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index 189b96d8d66..19d877ef15b 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -1175,6 +1175,55 @@ def set_prometheus_multiproc_dir() -> object:
         f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
 
 
+def confidential_compute_enabled() -> bool:
+    """
+    Query NVML for the confidential compute state
+    """
+
+    cc_enabled = False
+
+    try:
+        import pynvml
+    except ImportError:
+        # Without the NVML bindings we cannot query the state
+        return cc_enabled
+
+    try:
+        # Init
+        pynvml.nvmlInit()
+
+        # Hopper and newer support a more nuanced query of confidential
+        # compute settings
+        cc_settings = pynvml.c_nvmlSystemConfComputeSettings_v1_t()
+        if (pynvml.nvmlSystemGetConfComputeSettings(cc_settings) ==
+                pynvml.NVML_SUCCESS):
+            cc_enabled = (cc_settings.ccFeature
+                          == pynvml.NVML_CC_SYSTEM_FEATURE_ENABLED
+                          or cc_settings.multiGpuMode
+                          == pynvml.NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE
+                          or cc_settings.multiGpuMode
+                          == pynvml.NVML_CC_SYSTEM_MULTIGPU_NVLE)
+    except pynvml.NVMLError_NotSupported:
+        # Simple query for older GPUs
+        try:
+            cc_state = pynvml.nvmlSystemGetConfComputeState()
+            cc_enabled = (
+                cc_state.ccFeature == pynvml.NVML_CC_SYSTEM_FEATURE_ENABLED)
+        except Exception as e:
+            logger.error(f"Error querying confidential compute state: {str(e)}")
+    except Exception as e:
+        logger.error(f"Error querying confidential compute state: {str(e)}")
+    finally:
+        # Shutdown
+        try:
+            pynvml.nvmlShutdown()
+        except Exception:
+            # Ignore shutdown errors
+            pass
+
+    return cc_enabled
+
+
 P = ParamSpec("P")
 
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 49a4c3486d4..7d741bafa0b 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2565,6 +2565,11 @@ class TorchLlmArgs(BaseLlmArgs):
         "The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. 
Defaults to auto, which will use TorchSampler unless BeamSearch is requested.", status="beta") + enable_sampler_async_worker: bool = Field( + default=False, + description="Enable the async worker in the sampler for D2H copies.", + status="prototype") + enable_iter_perf_stats: bool = Field( default=False, description="Enable iteration performance statistics.", diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 4afc0d26bf6..1520fab8f83 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -232,10 +232,13 @@ def test_fp8_llm_sampler(self): @skip_pre_hopper @parametrize_with_ids("overlap_scheduler", [True, False]) @parametrize_with_ids("eagle3_one_model", [True, False]) - def test_eagle3(self, overlap_scheduler, eagle3_one_model): + @parametrize_with_ids("sampler_async_worker", [True, False]) + def test_eagle3(self, overlap_scheduler, eagle3_one_model, + sampler_async_worker): pytorch_config = dict( max_batch_size= 1, # add max_batch_size to avoid error in overlap scheduler + enable_sampler_async_worker=sampler_async_worker, disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig(max_batch_size=1, enable_padding=True), @@ -398,6 +401,7 @@ def test_auto_spec_decode(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @parametrize_with_ids("sampler_async_worker", [True, False]) @parametrize_with_ids("disable_overlap_scheduler", [False, True]) @parametrize_with_ids( "enable_cuda_graph,enable_padding", @@ -407,7 +411,8 @@ def test_auto_spec_decode(self): (True, True), # CUDA Graph with padding ]) def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding, - disable_overlap_scheduler): + disable_overlap_scheduler, + sampler_async_worker): max_beam_width = 2 sampling_params = SamplingParams(n=max_beam_width, best_of=max_beam_width, @@ -432,6 +437,7 @@ def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding, max_batch_size=max_beam_width, max_seq_len=2048, max_beam_width=max_beam_width, + enable_sampler_async_worker=sampler_async_worker, disable_overlap_scheduler=disable_overlap_scheduler, cuda_graph_config=cuda_graph_config, ) as llm: @@ -441,6 +447,7 @@ def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding, extra_acc_spec="beam_width=2") @skip_pre_hopper + @parametrize_with_ids("sampler_async_worker", [True, False]) @parametrize_with_ids("disable_overlap_scheduler", [False, True]) @parametrize_with_ids( "enable_cuda_graph,enable_padding", @@ -450,7 +457,7 @@ def test_auto_dtype_beam_search(self, enable_cuda_graph, enable_padding, (True, True), # CUDA Graph with padding ]) def test_fp8_beam_search(self, enable_cuda_graph, enable_padding, - disable_overlap_scheduler): + disable_overlap_scheduler, sampler_async_worker): model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8" max_beam_width = 2 sampling_params = SamplingParams(n=max_beam_width, @@ -476,6 +483,7 @@ def test_fp8_beam_search(self, enable_cuda_graph, enable_padding, max_seq_len=2048, max_beam_width=max_beam_width, disable_overlap_scheduler=disable_overlap_scheduler, + enable_sampler_async_worker=sampler_async_worker, cuda_graph_config=cuda_graph_config, ) @@ -506,14 +514,17 @@ def test_fp8_prequantized(self): @skip_pre_hopper @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("sampler_async_worker", [True, False]) @pytest.mark.parametrize("disable_overlap_scheduler", 
[True, False]) @pytest.mark.parametrize("pp_size", [2, 4], ids=["pp2", "pp4"]) - def test_return_logits_pp(self, pp_size, disable_overlap_scheduler): + def test_return_logits_pp(self, pp_size, disable_overlap_scheduler, + sampler_async_worker): prompts = ["A B C"] llm = LLM(model=self.MODEL_PATH, pipeline_parallel_size=pp_size, - disable_overlap_scheduler=disable_overlap_scheduler) + disable_overlap_scheduler=disable_overlap_scheduler, + enable_sampler_async_worker=sampler_async_worker) sampling_params = SamplingParams(max_tokens=8, return_context_logits=True, @@ -1478,6 +1489,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn, @pytest.mark.skip_less_device(4) @skip_pre_hopper @skip_ray + @parametrize_with_ids("sampler_async_worker", [True, False]) @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler", [(False, False, False, False), @@ -1493,7 +1505,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn, ids=["tp4", "ep4", "tp2pp2", "pp4"]) def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, attention_dp, cuda_graph, - overlap_scheduler, torch_compile): + overlap_scheduler, torch_compile, + sampler_async_worker): if torch_compile and pp_size > 1: pytest.skip("PP with torch.compile is not supported yet.") kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) @@ -1508,6 +1521,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn, torch_compile_config=torch_compile_config, moe_config=MoeConfig( backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"), + enable_sampler_async_worker=sampler_async_worker, ) if fp8kv: diff --git a/tests/integration/test_lists/qa/llm_digits_func.txt b/tests/integration/test_lists/qa/llm_digits_func.txt index 8cfe98fe119..30e3f223846 100644 --- a/tests/integration/test_lists/qa/llm_digits_func.txt +++ b/tests/integration/test_lists/qa/llm_digits_func.txt @@ -16,8 +16,9 @@ test_e2e.py::test_ptp_quickstart_advanced[Mistral-Nemo-12b-Base-Mistral-Nemo-Bas test_e2e.py::test_ptp_quickstart_advanced[DeepSeek-R1-Distill-Qwen-32B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-32B] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_llm_sampler -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=True-overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index eb56be01ce5..8ab1078d9d3 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -391,8 +391,9 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4 
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_llm_sampler -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=True-overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] @@ -404,18 +405,20 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False] 
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True-sampler_async_worker=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
@@ -581,6 +584,6 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True]
diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt
index 6be44cf20c8..981a94a8a30 100644
--- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt
+++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt
@@ -118,26 +118,29 @@ accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=FLASHINFER]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=True-overlap_scheduler=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_llm_sampler
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True] accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized 
accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype diff --git a/tests/integration/test_lists/qa/llm_function_l20.txt b/tests/integration/test_lists/qa/llm_function_l20.txt index 5015e7ee15c..772e39a683b 100644 --- a/tests/integration/test_lists/qa/llm_function_l20.txt +++ b/tests/integration/test_lists/qa/llm_function_l20.txt @@ -19,9 +19,11 @@ accuracy/test_llm_api.py::TestMistralNemo12B::test_fp8 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_llm_sampler -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=True-overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] diff --git a/tests/integration/test_lists/qa/llm_function_nim.txt b/tests/integration/test_lists/qa/llm_function_nim.txt index 40daaa151fc..17229ea5a61 100644 --- a/tests/integration/test_lists/qa/llm_function_nim.txt +++ b/tests/integration/test_lists/qa/llm_function_nim.txt @@ -211,18 +211,20 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True] 
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False] -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_auto_dtype_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=False-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=False-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=False-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=False] 
+accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True-sampler_async_worker=False] +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True] accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4 accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 9a875ff99d8..f681aa7c3bb 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -131,17 +131,21 @@ l0_dgx_h100: - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype1] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True-sampler_async_worker=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb diff --git a/tests/integration/test_lists/test-db/l0_dgx_h200.yml b/tests/integration/test_lists/test-db/l0_dgx_h200.yml index 
efacea5814d..a398d55037c 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h200.yml @@ -50,27 +50,31 @@ l0_dgx_h200: stage: post_merge backend: pytorch tests: - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-sampler_async_worker=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-sampler_async_worker=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False-sampler_async_worker=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 540a8580605..90886c4aded 100644 --- 
a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -56,8 +56,9 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=TRTLLM-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] - - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=False-overlap_scheduler=False] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=False-eagle3_one_model=True-overlap_scheduler=True] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index c08ec37c119..c2f43f93865 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -119,6 +119,10 @@ methods: annotation: Union[str, tensorrt_llm.llmapi.llm_args.SamplerType] default: auto status: beta + enable_sampler_async_worker: + annotation: bool + default: False + status: prototype enable_iter_perf_stats: annotation: bool default: False
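---
The hand-off pattern that AsyncWorkerMixin implements can be reduced to a short standalone sketch: the main thread records a CUDA event, a single worker thread (with its own thread-local device and stream) blocks on that event, performs the pinned-memory D2H copy, and synchronizes its stream before its future resolves. This is an illustrative reduction only, not the patch's API; it assumes PyTorch and a single CUDA device, and all names below are invented for the example.

import torch
from concurrent import futures


def _worker_init(device_id: int) -> None:
    # The current device and current stream are thread-local, so the worker
    # needs its own copies; a dedicated stream keeps the blocking copy from
    # gating work the main thread launches in the meantime.
    torch.cuda.set_device(device_id)
    torch.cuda.set_stream(torch.cuda.Stream())


def copy_to_host(pool: futures.ThreadPoolExecutor,
                 src: torch.Tensor) -> tuple[torch.Tensor, futures.Future]:
    dest = torch.empty_like(src, device="cpu", pin_memory=True)
    ready = torch.cuda.Event()
    ready.record()  # the point on the main stream the worker must wait for

    def _run() -> None:
        ready.synchronize()  # blocks only the worker thread, not the caller
        dest.copy_(src, non_blocking=True)
        torch.cuda.current_stream().synchronize()  # copy is done on return

    return dest, pool.submit(_run)


if torch.cuda.is_available():
    pool = futures.ThreadPoolExecutor(max_workers=1,
                                      initializer=_worker_init,
                                      initargs=(torch.cuda.current_device(),))
    x = torch.randn(1 << 20, device="cuda")
    host, fut = copy_to_host(pool, x)
    fut.result()  # analogous to waiting on the futures in SamplerEvent.synchronize()
    assert torch.equal(host, x.cpu())
    pool.shutdown(wait=True)

On the user-facing side, the behavior is opt-in through the new prototype flag, e.g. LLM(model=..., enable_sampler_async_worker=True) as exercised by the updated tests, and instantiate_sampler() also forces it on whenever confidential_compute_enabled() detects a confidential-compute environment, where D2H copies are substantially more expensive.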