Commit 4a8ea56

[test] Fix test_select_generated_logits
Signed-off-by: Robin Kobus <[email protected]>
1 parent: 8ce3bea

1 file changed: tests/unittest/_torch/sampler/test_torch_sampler.py (+60 −24 lines)
@@ -420,16 +420,22 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
 
     @contextmanager
     def _test_runner(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
+        draft_len_req1 = draft_len
+        draft_len_req2 = draft_len + 1  # test with different draft lens
+
         class ContextRequestMock:
-            def __init__(self, return_context_logits: bool):
+            def __init__(self, is_last_context_chunk: bool, return_context_logits: bool):
+                self.is_last_context_chunk = is_last_context_chunk
+                self.py_draft_tokens = torch.tensor([], dtype=torch.int32, device=device)
                 self._return_context_logits = return_context_logits
 
             @property
             def py_return_context_logits(self) -> bool:
                 return self._return_context_logits
 
         class GenRequestMock:
-            pass
+            def __init__(self, draft_len: int):
+                self.py_draft_tokens = torch.empty(draft_len, dtype=torch.int32, device=device)
 
         class ScheduledRequestsMock:
             @property
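The mocks above rely on duck typing plus typing.cast: each mock provides only the attributes the sampler actually reads, and cast silences the type checker with no runtime effect. A minimal sketch of that pattern, with illustrative names (Request stands in for a heavyweight class such as LlmRequest; none of these names come from the diff):

```python
# Minimal sketch of the duck-typing + cast() pattern; names are illustrative.
from typing import cast


class Request:
    def __init__(self, prompt_len: int):
        self.prompt_len = prompt_len
        # ... plus many fields a test would rather not construct


class RequestMock:
    # Provides only the attribute the code under test reads.
    def __init__(self, prompt_len: int):
        self.prompt_len = prompt_len


def total_prompt_len(requests: list[Request]) -> int:
    return sum(r.prompt_len for r in requests)


# cast() is a no-op at runtime; it only satisfies the type checker,
# so the mock never has to inherit from Request.
requests = [cast(Request, RequestMock(100)), cast(Request, RequestMock(50))]
assert total_prompt_len(requests) == 150
```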
@@ -438,9 +444,11 @@ def context_requests(self) -> list[LlmRequest]:
                     [
                         # NB: One request with py_return_context_logits is enough
                         # to trigger tested code.
-                        cast(LlmRequest, ContextRequestMock(True)),
-                        cast(LlmRequest, ContextRequestMock(False)),
-                        cast(LlmRequest, ContextRequestMock(True)),
+                        cast(LlmRequest, ContextRequestMock(True, True)),
+                        cast(LlmRequest, ContextRequestMock(True, False)),
+                        # This request is expected to be skipped
+                        cast(LlmRequest, ContextRequestMock(False, False)),
+                        cast(LlmRequest, ContextRequestMock(True, True)),
                     ]
                     if with_ctx
                     else []
@@ -452,35 +460,37 @@ def generation_requests(self) -> list[LlmRequest]:
                 # is not empty.
                 return (
                     [
-                        cast(LlmRequest, GenRequestMock()),
-                        cast(LlmRequest, GenRequestMock()),
+                        cast(LlmRequest, GenRequestMock(draft_len_req1)),
+                        cast(LlmRequest, GenRequestMock(draft_len_req2)),
                     ]
                     if with_gen
                     else []
                 )
 
-        vocab_size = 12
+        expected_num_generation_requests = with_ctx * 3 + with_gen * 2
 
         num_context_logits_prefix_sum = [
             0,
             *(
                 [
                     100 + 1,  # context req. 1 (assume context len. 100)
                     (100 + 1) + (0 + 1),  # context req. 2 (not returning context)
-                    (100 + 1) + (0 + 1) + (50 + 1),  # context req. 3 (assume context len. 50)
+                    (100 + 1) + (0 + 1) + (0 + 1),  # context req. 3 (not returning context)
+                    (100 + 1)
+                    + (0 + 1)
+                    + (0 + 1)
+                    + (50 + 1),  # context req. 4 (assume context len. 50)
                 ]
                 if with_ctx
                 else []
             ),
         ]
-        draft_len_req1 = draft_len
-        draft_len_req2 = draft_len + 1  # test with different draft lens
-        req_num_generation_steps = [
+        expected_req_num_generation_steps = [
             *(
                 [
                     1,  # context req. 1
                     1,  # context req. 2
-                    1,  # context req. 3
+                    1,  # context req. 4
                 ]
                 if with_ctx
                 else []
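A quick worked check of the prefix-sum values in the hunk above, assuming with_ctx is true: each mocked context request occupies (returned context length + 1) rows of the logits buffer, and requests that do not return context logits contribute 0 context rows. This sketch is my reading of the comments in the diff, not code from it:

```python
# Reproduces num_context_logits_prefix_sum for with_ctx=True.
import itertools

per_request_rows = [
    100 + 1,  # req. 1: 100 returned context logits + 1 generation logit
    0 + 1,    # req. 2: context logits not returned
    0 + 1,    # req. 3: not the last context chunk, row later skipped
    50 + 1,   # req. 4: 50 returned context logits + 1 generation logit
]

prefix_sum = [0, *itertools.accumulate(per_request_rows)]
assert prefix_sum == [0, 101, 102, 103, 154]
```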
@@ -494,12 +504,20 @@ def generation_requests(self) -> list[LlmRequest]:
                 else []
             ),
         ]
-        req_num_generation_steps_tensor = torch.tensor(req_num_generation_steps, dtype=torch.int32)
-        num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item())
+        expected_req_num_generation_steps_tensor = torch.tensor(
+            expected_req_num_generation_steps, dtype=torch.int32
+        )
+
+        expected_req_offsets = torch.cumsum(expected_req_num_generation_steps_tensor, dim=0).roll(1)
+        expected_req_offsets[0] = 0
+
+        # num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item())
         generation_requests_total_steps = (draft_len_req1 + 1) + (
             draft_len_req2 + 1
         )  # cf. req_num_generation_steps
 
+        vocab_size = 12
+
         num_total_steps = num_context_logits_prefix_sum[-1] + generation_requests_total_steps
         all_logits = torch.empty((num_total_steps, vocab_size))
 
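The cumsum(...).roll(1) idiom above converts per-request step counts into exclusive start offsets: roll moves the running totals one slot to the right, wrapping the grand total into position 0, which is then overwritten with 0. A sketch with illustrative values (three context requests with 1 step each, two gen requests with draft_len + 1 = 3 and 4 steps, assuming draft_len = 2):

```python
import torch

steps = torch.tensor([1, 1, 1, 3, 4], dtype=torch.int32)

offsets = torch.cumsum(steps, dim=0).roll(1)  # [10, 1, 2, 3, 6]
offsets[0] = 0  # roll() wrapped the grand total around to the front

# Exclusive prefix sum: request i's logits start at offsets[i].
assert offsets.tolist() == [0, 1, 2, 3, 6]
```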
@@ -513,7 +531,8 @@ def generation_requests(self) -> list[LlmRequest]:
             expected_logit_indices += [
                 100,  # gen logits from context req. 1
                 101,  # gen logits from context req. 2
-                152,  # gen logits from context req. 3
+                # 102,  # skipped gen logits from context req. 3
+                153,  # gen logits from context req. 4
             ]
         if with_gen:
             gen_logit_offset = num_context_logits_prefix_sum[-1]
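As I read the updated indices: each context request's single generation logit is the last row of its block in all_logits, i.e. prefix_sum[i + 1] - 1, and request 3 (index 2) is dropped because its is_last_context_chunk is False. A small check under that assumption:

```python
# Hypothetical reconstruction of the index arithmetic; prefix_sum matches
# num_context_logits_prefix_sum for with_ctx=True, and index 2 marks the
# context request whose generation logit is skipped.
prefix_sum = [0, 101, 102, 103, 154]
skipped = {2}

expected = [prefix_sum[i + 1] - 1 for i in range(4) if i not in skipped]
assert expected == [100, 101, 153]
```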
@@ -527,6 +546,8 @@ def generation_requests(self) -> list[LlmRequest]:
                 ),  # gen logits from gen. req. 2
             ]
 
+        expected_logits = all_logits[expected_logit_indices]
+
         @dataclass
         class UutResult:
             selected_requests: list[LlmRequest]
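The new expected_logits = all_logits[expected_logit_indices] line uses integer-array (advanced) indexing, which gathers the listed rows into a fresh tensor in the given order. A tiny illustration with made-up shapes:

```python
import torch

all_logits = torch.arange(12.0).reshape(6, 2)  # 6 rows of fake logits
rows = [0, 3, 5]

picked = all_logits[rows]  # gathers rows 0, 3 and 5, in that order
assert picked.shape == (3, 2)
assert torch.equal(picked, torch.stack([all_logits[0], all_logits[3], all_logits[5]]))
```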
@@ -542,22 +563,37 @@ class UutResultWrapper:
         res = UutResultWrapper()
 
         def _uut(res=res):
-            selected_logits = TorchSampler._select_generated_logits(
+            (
+                selected_requests,
+                req_num_generation_steps_list,
+                req_num_generation_steps,
+                req_offsets,
+                selected_logits,
+            ) = TorchSampler._select_generated_logits(
                 cast(ScheduledRequests, ScheduledRequestsMock()),
                 all_logits_cuda,
-                req_num_generation_steps=req_num_generation_steps_tensor,
                 num_context_logits_prefix_sum=num_context_logits_prefix_sum,
-                generation_requests_total_steps=generation_requests_total_steps,
-                num_logits_to_keep=num_logits_to_keep,
             )
-            res.result = UutResult(selected_logits=selected_logits)
+            res.result = UutResult(
+                selected_requests=selected_requests,
+                req_num_generation_steps_list=req_num_generation_steps_list,
+                req_num_generation_steps=req_num_generation_steps,
+                req_offsets=req_offsets,
+                selected_logits=selected_logits,
+            )
 
         yield _uut
 
-        # Check logits
+        # Check results
         assert res.result is not None
-        selected_logits = res.result.selected_logits
-        torch.testing.assert_close(selected_logits.to("cpu"), all_logits[expected_logit_indices])
+
+        assert len(res.result.selected_requests) == expected_num_generation_requests
+        torch.testing.assert_close(
+            res.result.req_num_generation_steps.to("cpu"), expected_req_num_generation_steps_tensor
+        )
+        assert res.result.req_num_generation_steps_list == expected_req_num_generation_steps
+        torch.testing.assert_close(res.result.req_offsets.to("cpu"), expected_req_offsets)
+        torch.testing.assert_close(res.result.selected_logits.to("cpu"), expected_logits)
 
     _run_test_with_warmup(_test_runner, max_sync_s=0.3)
 