|
8 | 8 | import torch |
9 | 9 |
|
10 | 10 | import xgrammar as xgr |
11 | | -from xgrammar.testing import _bool_mask_to_bitmask, _get_masked_tokens_from_bitmask |
| 11 | +from xgrammar.testing import ( |
| 12 | + _bool_mask_to_bitmask, |
| 13 | + _get_masked_tokens_from_bitmask, |
| 14 | + _is_single_token_bitmask, |
| 15 | +) |
12 | 16 |
|
13 | 17 | _is_cuda_available = torch.cuda.is_available() |
14 | 18 | _is_mps_available = torch.backends.mps.is_available() |
@@ -38,6 +42,23 @@ def test_get_masked_tokens_from_bitmask(token_mask_size: int, index: int): |
38 | 42 | assert _get_masked_tokens_from_bitmask(bitmask, token_mask_size, index) == expected |
39 | 43 |
|
40 | 44 |
|
| 45 | +def test_is_single_token_bitmask(): |
| 46 | + batch = 2 |
| 47 | + batch_index = 1 |
| 48 | + vocab_size = 1024 |
| 49 | + token_id = 100 |
| 50 | + |
| 51 | + bool_mask = torch.zeros(batch, vocab_size, dtype=torch.bool) |
| 52 | + bitmask = _bool_mask_to_bitmask(bool_mask) |
| 53 | + assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1) |
| 54 | + bool_mask[batch_index, token_id] = True |
| 55 | + bitmask = _bool_mask_to_bitmask(bool_mask) |
| 56 | + assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (True, token_id) |
| 57 | + bool_mask[batch_index, token_id + 1] = True |
| 58 | + bitmask = _bool_mask_to_bitmask(bool_mask) |
| 59 | + assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1) |
| 60 | + |
| 61 | + |
41 | 62 | @pytest.mark.parametrize("device", ("cpu", "cuda")) |
42 | 63 | def test_apply_token_bitmask_inplace(device: str): |
43 | 64 | if device == "cuda" and not _is_cuda_available: |
|
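For context on what the new helper checks: xgrammar's token bitmask packs one bit per vocabulary token into 32-bit words, so a "single-token" bitmask is one whose row has exactly one bit set among the first `vocab_size` bits. The sketch below re-implements that property check for illustration only; it is not xgrammar's actual implementation of `_is_single_token_bitmask`, the helper names `unpack_bitmask` and `is_single_token` are invented here, and the packing convention (bit `i` of word `w` maps to token id `w * 32 + i`) is an assumption based on the usual layout.

```python
import torch


def unpack_bitmask(bitmask: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Unpack (batch, ceil(vocab_size / 32)) int32 words into a (batch, vocab_size)
    bool mask. Assumes bit i of word w corresponds to token id w * 32 + i."""
    shifts = torch.arange(32, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1  # (batch, words, 32), one bit per token
    return bits.flatten(start_dim=1)[:, :vocab_size].bool()


def is_single_token(bitmask: torch.Tensor, vocab_size: int, index: int):
    """Mirror the semantics the test asserts for _is_single_token_bitmask:
    (True, token_id) if exactly one token is allowed in row `index`, else (False, -1)."""
    allowed = unpack_bitmask(bitmask[index : index + 1], vocab_size)[0]
    ids = allowed.nonzero(as_tuple=False).flatten()
    return (True, int(ids[0])) if ids.numel() == 1 else (False, -1)
```

Run against the test's setup, this sketch agrees with the asserted values: an all-zero mask yields `(False, -1)`, setting only bit 100 in row 1 yields `(True, 100)`, and setting a second bit returns it to `(False, -1)`.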