@@ -945,15 +945,11 @@ def seq_slot_assignment(
     def mock_requests(
         self,
         sampling_params_list: list[SamplingParams],
-        with_draft_logits: bool,
-        vocab_size: int,
         seq_slot_assignment: tuple[list[int], int],
         draft_lens: list[int],
     ) -> ScheduledRequests:
         return self._build_mock_requests(
             sampling_params_list=sampling_params_list,
-            with_draft_logits=with_draft_logits,
-            vocab_size=vocab_size,
             seq_slot_assignment=seq_slot_assignment,
             draft_lens=draft_lens,
         )
@@ -962,8 +958,6 @@ def _build_mock_requests(
         self,
         sampling_params_list: list[SamplingParams],
         *,
-        with_draft_logits: bool,
-        vocab_size: int,
         seq_slot_assignment: tuple[list[int], int],
         draft_lens: list[int],
     ) -> ScheduledRequests:
@@ -975,21 +969,9 @@ def __init__(
                 self,
                 sampling_params_list: list[SamplingParams],
                 *,
-                with_draft_logits: bool,
                 draft_lens: list[int],
             ):
                 self._sampling_params_list = sampling_params_list
-                self._with_draft_logits = with_draft_logits
-
-                def _attach_draft_logits(req: LlmRequest) -> LlmRequest:
-                    draft_len = len(req.py_draft_tokens)
-                    if draft_len and with_draft_logits:
-                        req.py_draft_logits = torch.testing.make_tensor(  # type: ignore
-                            (draft_len, vocab_size),
-                            dtype=torch.float32,
-                            device="cuda",
-                        )
-                    return req
 
                 # NB:
                 # - stop words are tested in test_write_finish_reasons
@@ -999,24 +981,22 @@ def _attach_draft_logits(req: LlmRequest) -> LlmRequest:
                 # - py_return_log_probs is tested elsewhere
                 # - code paths gated by py_return_context_logits tested in test_select_generated_logits
                 self._gen_requests = [
-                    _attach_draft_logits(
-                        LlmRequest(
-                            request_id=seq_slot,
-                            max_new_tokens=(2 * draft_len),  # not used by tested code
-                            input_tokens=[12],  # not used by tested code
-                            sampling_config=SamplingConfig(sampling_params._get_sampling_config()),
-                            seq_slot=seq_slot,
-                            is_streaming=False,  # not relevant for tested code
-                            draft_tokens=(  # 'len(.py_draft_tokens)' is inspected by get_draft_token_length
-                                torch.testing.make_tensor(
-                                    (draft_len,),
-                                    dtype=torch.int32,
-                                    device="cpu",
-                                ).tolist()
-                                if draft_len
-                                else None
-                            ),
-                        )
+                    LlmRequest(
+                        request_id=seq_slot,
+                        max_new_tokens=(2 * draft_len),  # not used by tested code
+                        input_tokens=[12],  # not used by tested code
+                        sampling_config=SamplingConfig(sampling_params._get_sampling_config()),
+                        seq_slot=seq_slot,
+                        is_streaming=False,  # not relevant for tested code
+                        draft_tokens=(  # 'len(.py_draft_tokens)' is inspected by get_draft_token_length
+                            torch.testing.make_tensor(
+                                (draft_len,),
+                                dtype=torch.int32,
+                                device="cpu",
+                            ).tolist()
+                            if draft_len
+                            else None
+                        ),
                     )
                     for sampling_params, seq_slot, draft_len in zip(
                         sampling_params_list, seq_slots, draft_lens
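
As the inline comment notes, the tested code only reads len(request.py_draft_tokens) (via get_draft_token_length), which is why arbitrary int32 values of the right length suffice. A minimal sketch of that contract; get_draft_token_length here is a simplified stand-in for the real helper, and _FakeRequest is hypothetical:

    import torch

    def get_draft_token_length(request) -> int:  # simplified stand-in, not the real helper
        draft_tokens = getattr(request, "py_draft_tokens", None)
        return len(draft_tokens) if draft_tokens is not None else 0

    class _FakeRequest:
        def __init__(self, draft_len: int):
            # Only the length matters to the tested code; the values are arbitrary.
            self.py_draft_tokens = (
                torch.testing.make_tensor((draft_len,), dtype=torch.int32, device="cpu").tolist()
                if draft_len
                else None
            )

    assert get_draft_token_length(_FakeRequest(3)) == 3
    assert get_draft_token_length(_FakeRequest(0)) == 0
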
@@ -1040,9 +1020,7 @@ def all_requests(self) -> list[LlmRequest]:
         with torch.inference_mode(True):
             return cast(
                 ScheduledRequests,
-                ScheduledRequestsMock(
-                    sampling_params_list, with_draft_logits=with_draft_logits, draft_lens=draft_lens
-                ),
+                ScheduledRequestsMock(sampling_params_list, draft_lens=draft_lens),
             )
 
     @pytest.fixture(scope="function")
@@ -1184,20 +1162,17 @@ def test_backend_selection(
11841162 "max_draft_len" ,
11851163 "draft_lens" ,
11861164 "sampling_params_list" ,
1187- "with_draft_logits" ,
11881165 "params_label" ,
11891166 "allow_zero_draft_len" ,
11901167 "vocab_size" ,
11911168 ),
11921169 [
1193- # NB: with_draft_logits=True and non-zero draft len ensures that
1194- # LlmRequest.py_target_probs is set.
1170+ # NB: non-zero draft len ensures that LlmRequest.py_target_probs is set.
11951171 pytest .param (
11961172 use_flashinfer ,
11971173 3 ,
11981174 [3 ] * len (sampling_params_list ),
11991175 sampling_params_list ,
1200- True ,
12011176 params_label ,
12021177 False ,
12031178 vocab_size ,
@@ -1225,7 +1200,6 @@ def test_probs(
         allow_zero_draft_len: bool,  # used by fixtures
         sampling_params_list: list[SamplingParams],
         seq_slot_assignment: tuple[list[int], int],
-        with_draft_logits: bool,
     ):
         """Validate probabilities returned by sample_async.
 
@@ -1255,9 +1229,7 @@ def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
             # requests.
             uut_mock_requests = self._build_mock_requests(
                 sampling_params_list=sampling_params_list,
-                vocab_size=vocab_size,
                 seq_slot_assignment=seq_slot_assignment,
-                with_draft_logits=with_draft_logits,
                 draft_lens=draft_lens,
             )
         else:
@@ -1427,11 +1399,8 @@ def _compute_probs(
         )
         mock_requests_with_probs = self._build_mock_requests(
             sampling_params_list=sampling_params_list,
-            vocab_size=vocab_size,
             seq_slot_assignment=seq_slot_assignment,
-            # NB: with_draft_logits=True and non-zero draft len ensures that
-            # LlmRequest.py_target_probs is set.
-            with_draft_logits=True,
+            # NB: non-zero draft len ensures that LlmRequest.py_target_probs is set.
             draft_lens=([draft_len_with_probs] * len(sampling_params_list)),
         )
         # zero-pad logits to draft_len_with_probs
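
The hunk is cut off at the zero-padding step. For intuition, zero-padding each request's logits along the step axis keeps a batch with mixed draft lengths rectangular; a rough sketch with assumed shapes (not taken from the test):

    import torch
    import torch.nn.functional as F

    draft_len_with_probs = 3
    vocab_size = 32
    # Hypothetical logits for a request with draft_len = 1: one row per draft
    # step plus one for the final token.
    logits = torch.randn(2, vocab_size)
    pad_rows = (draft_len_with_probs + 1) - logits.shape[0]
    padded = F.pad(logits, (0, 0, 0, pad_rows))  # append zero rows at the end
    assert padded.shape == (draft_len_with_probs + 1, vocab_size)
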
@@ -1818,6 +1787,13 @@ def _validate_token_frequencies(
 
         # Perform G-test (asymptotically approximated by Pearson's chi-square test) to
         # check that sampled tokens are consistent with the expected probs.
+        #
+        # NB: Use FP64 to avoid negative test statistic values, and renormalize
+        # the expected counts so that their total matches the observed total.
+        test_token_counts_ma = test_token_counts_ma.astype(np.float64)
+        test_expected_counts_ma = test_expected_counts_ma.astype(np.float64)
+        test_expected_counts_ma /= test_expected_counts_ma.sum(axis=-1, keepdims=True)
+        test_expected_counts_ma *= num_samples
         test_result = power_divergence(
             f_obs=test_token_counts_ma,
             f_exp=test_expected_counts_ma,
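
Why the FP64 cast and renormalization matter: the G-statistic is 2 * sum(obs * ln(obs / exp)), which can come out slightly negative in FP32 when observed and expected counts nearly coincide, and SciPy's power_divergence also requires the observed and expected totals to agree to a small relative tolerance. A minimal, self-contained sketch of the same pattern (illustrative values, not from the test):

    import numpy as np
    from scipy.stats import power_divergence

    num_samples = 100_000
    probs = np.full(8, 1 / 8)
    counts = np.random.default_rng(0).multinomial(num_samples, probs).astype(np.float64)

    # Rescale expected counts so that sum(f_exp) == sum(f_obs), as
    # power_divergence demands; lambda_="log-likelihood" selects the G-test.
    expected = probs / probs.sum() * counts.sum()
    result = power_divergence(f_obs=counts, f_exp=expected, lambda_="log-likelihood")
    print(result.statistic, result.pvalue)  # statistic stays non-negative in FP64
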
@@ -1847,7 +1822,6 @@ def _validate_token_frequencies(
18471822 "use_flashinfer" ,
18481823 "max_draft_len" ,
18491824 "sampling_params_list" ,
1850- "with_draft_logits" ,
18511825 "allow_zero_draft_len" ,
18521826 "bypass_sampling" ,
18531827 "vocab_size" ,
@@ -1857,7 +1831,6 @@ def _validate_token_frequencies(
                 use_flashinfer,
                 max_draft_len,
                 sampling_params_list,
-                with_draft_logits,
                 allow_zero_draft_len,
                 # Run full sampling test only for uniform batches, with/without probs, but skip
                 # sampling statistics when varying draft lens etc. to validate batch handling:
@@ -1868,22 +1841,20 @@ def _validate_token_frequencies(
18681841 id = (
18691842 f"{ 'FlashInfer' if use_flashinfer else 'Torch' } "
18701843 f"-draft_len={ 0 if allow_zero_draft_len else 1 } ..{ max_draft_len } "
1871- f"-return_probs= { with_draft_logits } - { params_label } "
1844+ f"-{ params_label } "
18721845 ),
18731846 )
18741847 # https://stackoverflow.com/a/75421799, does not work with nested loops
18751848 for (
18761849 use_flashinfer ,
18771850 is_mixed ,
1878- with_draft_logits ,
18791851 max_draft_len ,
18801852 allow_zero_draft_len ,
18811853 _build_test_cases ,
18821854 vocab_size ,
18831855 ) in product (
18841856 [False , True ],
18851857 [False , True ],
1886- [True , False ],
18871858 [0 , 3 ],
18881859 [False , True ],
18891860 [_build_test_cases ],
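
The pattern above (see the linked Stack Overflow answer) replaces nested parametrize loops with a single itertools.product over the test axes, filtered by the comprehension's trailing if. A generic sketch with made-up axis names, not the suite's actual parameters:

    from itertools import product

    import pytest

    CASES = [
        pytest.param(backend, draft_len, id=f"{backend}-draft_len={draft_len}")
        for backend, draft_len in product(["torch", "flashinfer"], [0, 3])
        if draft_len > 0 or backend == "torch"  # drop invalid combinations
    ]

    @pytest.mark.parametrize(("backend", "draft_len"), CASES)
    def test_example(backend, draft_len):
        assert backend in ("torch", "flashinfer")
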
@@ -1895,16 +1866,14 @@ def _validate_token_frequencies(
                 include_uniform=(not is_mixed),
                 include_mixed=is_mixed,
             )
-            if (allow_zero_draft_len or max_draft_len > 0)
-            and (not with_draft_logits or max_draft_len > 0)
+            if allow_zero_draft_len or max_draft_len > 0
         ],
     )
     def test_samples(
         self,
         vocab_size: int,
         sampling_params_list: list[SamplingParams],
         seq_slot_assignment: tuple[list[int], int],
-        with_draft_logits: bool,
         max_draft_len: int,
         use_flashinfer: bool,
         allow_zero_draft_len: bool,  # used by fixtures
@@ -2038,7 +2007,7 @@ def _uut(res=res):
                 probs = probs[: (draft_len + 1)]
 
                 # check probs are returned only when needed
-                should_return_probs = draft_len and with_draft_logits
+                should_return_probs = bool(draft_len)
                 assert (
                     hasattr(req, "py_target_probs") and req.py_target_probs is not None
                 ) == should_return_probs