@@ -420,16 +420,22 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
 
     @contextmanager
     def _test_runner(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
+        draft_len_req1 = draft_len
+        draft_len_req2 = draft_len + 1  # test with different draft lens
+
         class ContextRequestMock:
-            def __init__(self, return_context_logits: bool):
+            def __init__(self, is_last_context_chunk: bool, return_context_logits: bool):
+                self.is_last_context_chunk = is_last_context_chunk
+                self.py_draft_tokens = torch.tensor([], dtype=torch.int32, device=device)
                 self._return_context_logits = return_context_logits
 
             @property
             def py_return_context_logits(self) -> bool:
                 return self._return_context_logits
 
         class GenRequestMock:
-            pass
+            def __init__(self, draft_len: int):
+                self.py_draft_tokens = torch.empty(draft_len, dtype=torch.int32, device=device)
 
         class ScheduledRequestsMock:
             @property
@@ -438,9 +444,11 @@ def context_requests(self) -> list[LlmRequest]:
                     [
                         # NB: One request with py_return_context_logits is enough
                         # to trigger tested code.
-                        cast(LlmRequest, ContextRequestMock(True)),
-                        cast(LlmRequest, ContextRequestMock(False)),
-                        cast(LlmRequest, ContextRequestMock(True)),
+                        cast(LlmRequest, ContextRequestMock(True, True)),
+                        cast(LlmRequest, ContextRequestMock(True, False)),
+                        # This request is expected to be skipped
+                        cast(LlmRequest, ContextRequestMock(False, False)),
+                        cast(LlmRequest, ContextRequestMock(True, True)),
                     ]
                     if with_ctx
                     else []
@@ -452,35 +460,37 @@ def generation_requests(self) -> list[LlmRequest]:
                 # is not empty.
                 return (
                     [
-                        cast(LlmRequest, GenRequestMock()),
-                        cast(LlmRequest, GenRequestMock()),
+                        cast(LlmRequest, GenRequestMock(draft_len_req1)),
+                        cast(LlmRequest, GenRequestMock(draft_len_req2)),
                     ]
                     if with_gen
                     else []
                 )
 
-        vocab_size = 12
+        expected_num_generation_requests = with_ctx * 3 + with_gen * 2
 
         num_context_logits_prefix_sum = [
             0,
             *(
                 [
                     100 + 1,  # context req. 1 (assume context len. 100)
                     (100 + 1) + (0 + 1),  # context req. 2 (not returning context)
-                    (100 + 1) + (0 + 1) + (50 + 1),  # context req. 3 (assume context len. 50)
+                    (100 + 1) + (0 + 1) + (0 + 1),  # context req. 3 (not returning context)
+                    (100 + 1)
+                    + (0 + 1)
+                    + (0 + 1)
+                    + (50 + 1),  # context req. 4 (assume context len. 50)
                 ]
                 if with_ctx
                 else []
             ),
         ]
-        draft_len_req1 = draft_len
-        draft_len_req2 = draft_len + 1  # test with different draft lens
-        req_num_generation_steps = [
+        expected_req_num_generation_steps = [
             *(
                 [
                     1,  # context req. 1
                     1,  # context req. 2
-                    1,  # context req. 3
+                    1,  # context req. 4
                 ]
                 if with_ctx
                 else []
@@ -494,12 +504,20 @@ def generation_requests(self) -> list[LlmRequest]:
                 else []
             ),
         ]
-        req_num_generation_steps_tensor = torch.tensor(req_num_generation_steps, dtype=torch.int32)
-        num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item())
+        expected_req_num_generation_steps_tensor = torch.tensor(
+            expected_req_num_generation_steps, dtype=torch.int32
+        )
+
+        expected_req_offsets = torch.cumsum(expected_req_num_generation_steps_tensor, dim=0).roll(1)
+        expected_req_offsets[0] = 0
+
+        # num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item())
         generation_requests_total_steps = (draft_len_req1 + 1) + (
             draft_len_req2 + 1
         )  # cf. req_num_generation_steps
 
+        vocab_size = 12
+
         num_total_steps = num_context_logits_prefix_sum[-1] + generation_requests_total_steps
         all_logits = torch.empty((num_total_steps, vocab_size))
 
@@ -513,7 +531,8 @@ def generation_requests(self) -> list[LlmRequest]:
             expected_logit_indices += [
                 100,  # gen logits from context req. 1
                 101,  # gen logits from context req. 2
-                152,  # gen logits from context req. 3
+                # 102,  # skipped gen logits from context req. 3
+                153,  # gen logits from context req. 4
             ]
         if with_gen:
             gen_logit_offset = num_context_logits_prefix_sum[-1]
@@ -527,6 +546,8 @@ def generation_requests(self) -> list[LlmRequest]:
                 ),  # gen logits from gen. req. 2
             ]
 
+        expected_logits = all_logits[expected_logit_indices]
+
         @dataclass
         class UutResult:
             selected_requests: list[LlmRequest]
@@ -542,22 +563,37 @@ class UutResultWrapper:
         res = UutResultWrapper()
 
         def _uut(res=res):
-            selected_logits = TorchSampler._select_generated_logits(
+            (
+                selected_requests,
+                req_num_generation_steps_list,
+                req_num_generation_steps,
+                req_offsets,
+                selected_logits,
+            ) = TorchSampler._select_generated_logits(
                 cast(ScheduledRequests, ScheduledRequestsMock()),
                 all_logits_cuda,
-                req_num_generation_steps=req_num_generation_steps_tensor,
                 num_context_logits_prefix_sum=num_context_logits_prefix_sum,
-                generation_requests_total_steps=generation_requests_total_steps,
-                num_logits_to_keep=num_logits_to_keep,
             )
-            res.result = UutResult(selected_logits=selected_logits)
+            res.result = UutResult(
+                selected_requests=selected_requests,
+                req_num_generation_steps_list=req_num_generation_steps_list,
+                req_num_generation_steps=req_num_generation_steps,
+                req_offsets=req_offsets,
+                selected_logits=selected_logits,
+            )
 
         yield _uut
 
-        # Check logits
+        # Check results
         assert res.result is not None
-        selected_logits = res.result.selected_logits
-        torch.testing.assert_close(selected_logits.to("cpu"), all_logits[expected_logit_indices])
+
+        assert len(res.result.selected_requests) == expected_num_generation_requests
+        torch.testing.assert_close(
+            res.result.req_num_generation_steps.to("cpu"), expected_req_num_generation_steps_tensor
+        )
+        assert res.result.req_num_generation_steps_list == expected_req_num_generation_steps
+        torch.testing.assert_close(res.result.req_offsets.to("cpu"), expected_req_offsets)
+        torch.testing.assert_close(res.result.selected_logits.to("cpu"), expected_logits)
 
     _run_test_with_warmup(_test_runner, max_sync_s=0.3)
 
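Aside, not part of the diff: the new expected_req_offsets check relies on a cumsum-and-roll idiom to turn per-request generation-step counts into exclusive start offsets. A minimal standalone sketch of what that produces, with illustrative step counts (not values from the test):

    import torch

    # Three requests with 2, 3 and 1 generation steps (illustrative values).
    steps = torch.tensor([2, 3, 1], dtype=torch.int32)
    offsets = torch.cumsum(steps, dim=0).roll(1)  # cumsum [2, 5, 6] rolled to [6, 2, 5]
    offsets[0] = 0                                # -> [0, 2, 5]: first logit row of each request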
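Also not part of the diff: a rough re-derivation, under the test's assumptions rather than the sampler's actual code, of which logit rows the updated test expects TorchSampler._select_generated_logits to keep. It mirrors the mock setup above: four context requests with prefix sums [0, 101, 102, 103, 154], the third skipped because it is not a last context chunk, and two generation requests with an assumed draft_len of 2:

    import torch

    prefix_sum = [0, 101, 102, 103, 154]        # num_context_logits_prefix_sum from the test
    is_last_context_chunk = [True, True, False, True]
    gen_steps = [3, 4]                          # draft_len_req1 + 1, draft_len_req2 + 1

    rows: list[int] = []
    for i, last in enumerate(is_last_context_chunk):
        if last:
            rows.append(prefix_sum[i + 1] - 1)  # last logit row of the context chunk
    offset = prefix_sum[-1]
    for steps in gen_steps:
        rows.extend(range(offset, offset + steps))  # draft_len + 1 rows per generation request
        offset += steps

    # rows == [100, 101, 153, 154, 155, 156, 157, 158, 159, 160], matching expected_logit_indices
    all_logits = torch.randn(prefix_sum[-1] + sum(gen_steps), 12)
    selected = all_logits[rows]                 # cf. expected_logits above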