
Commit d548210

Add util function testing._is_single_token_bitmask (#294)
This PR adds a function _is_single_token_bitmask to xgrammar.testing that checks whether a bitmask allows exactly one token, returning that token's id if so.
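A minimal usage sketch of the new helper (both functions come from this diff; the vocabulary size and token id below are made up for illustration):

    import torch
    from xgrammar.testing import _bool_mask_to_bitmask, _is_single_token_bitmask

    vocab_size = 1024
    bool_mask = torch.zeros(1, vocab_size, dtype=torch.bool)
    bool_mask[0, 42] = True                     # allow exactly one token
    bitmask = _bool_mask_to_bitmask(bool_mask)  # pack bools into 32-bit blocks
    assert _is_single_token_bitmask(bitmask, vocab_size) == (True, 42)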
Parent: 536b372 · Commit: d548210

8 files changed: +110 additions, −4 deletions

cpp/grammar_matcher.cc

Lines changed: 10 additions & 0 deletions
@@ -66,6 +66,16 @@ void _DebugGetMaskedTokensFromBitmask(
   }
 }
 
+std::pair<bool, int> _IsSingleTokenBitmask(const DLTensor& bitmask, int vocab_size, int index) {
+  int32_t* data_ptr = CheckAndGetBitmaskPtr(bitmask, vocab_size, index);
+  DynamicBitset bitset(vocab_size, reinterpret_cast<uint32_t*>(data_ptr));
+  if (bitset.Count() == 1) {
+    return std::make_pair(true, bitset.FindFirstOne());
+  } else {
+    return std::make_pair(false, -1);
+  }
+}
+
 void ApplyTokenBitmaskInplaceCPU(
     DLTensor* logits,
     const DLTensor& bitmask,
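The bitmask packs one bit per token into 32-bit blocks, with bit i set when token i is allowed; the function checks that exactly one bit is set and reports its position. A pure-Python model of that logic, illustrative only (the real code goes through DynamicBitset, and padding bits past vocab_size must be ignored since _bool_mask_to_bitmask pads with ones):

    def is_single_token_bitmask(blocks, vocab_size):
        # Mask each 32-bit block; bits at positions >= vocab_size are padding.
        masked = []
        for i, block in enumerate(blocks):
            block &= 0xFFFFFFFF
            if (i + 1) * 32 > vocab_size:
                block &= (1 << max(vocab_size - i * 32, 0)) - 1
            masked.append(block)
        # cf. DynamicBitset::Count(): number of set (allowed) bits.
        if sum(bin(block).count("1") for block in masked) != 1:
            return (False, -1)
        # cf. DynamicBitset::FindFirstOne(): position of the lowest set bit.
        for i, block in enumerate(masked):
            if block:
                return (True, i * 32 + (block & -block).bit_length() - 1)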

cpp/nanobind/nanobind.cc

Lines changed: 2 additions & 1 deletion
@@ -210,7 +210,8 @@ NB_MODULE(xgrammar_bindings, m) {
       )
       .def("_regex_to_ebnf", &RegexToEBNF)
       .def("_ebnf_to_grammar_no_normalization", &_EBNFToGrammarNoNormalization)
-      .def("_get_masked_tokens_from_bitmask", &Matcher_DebugGetMaskedTokensFromBitmask)
+      .def("_get_masked_tokens_from_bitmask", &Testing_DebugGetMaskedTokensFromBitmask)
+      .def("_is_single_token_bitmask", &Testing_IsSingleTokenBitmask)
       .def("_get_allow_empty_rule_ids", &GetAllowEmptyRuleIds)
       .def(
           "_generate_range_regex",

cpp/nanobind/python_methods.cc

Lines changed: 19 additions & 1 deletion
@@ -60,7 +60,7 @@ bool GrammarMatcher_FillNextTokenBitmask(
   return matcher.FillNextTokenBitmask(&bitmask_dltensor, index, debug_print);
 }
 
-std::vector<int> Matcher_DebugGetMaskedTokensFromBitmask(
+std::vector<int> Testing_DebugGetMaskedTokensFromBitmask(
     intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
 ) {
   XGRAMMAR_CHECK(shape.size() == 1 || shape.size() == 2) << "token_bitmask tensor must be 1D or 2D";
@@ -80,6 +80,24 @@ std::vector<int> Matcher_DebugGetMaskedTokensFromBitmask(
   return result;
 }
 
+std::pair<bool, int> Testing_IsSingleTokenBitmask(
+    intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
+) {
+  XGRAMMAR_CHECK(shape.size() == 1 || shape.size() == 2) << "token_bitmask tensor must be 1D or 2D";
+
+  DLTensor bitmask_dltensor{
+      reinterpret_cast<void*>(token_bitmask_ptr),
+      DLDevice{kDLCPU, 0},
+      static_cast<int32_t>(shape.size()),
+      GetBitmaskDLType(),
+      shape.data(),
+      nullptr,
+      0
+  };
+
+  return _IsSingleTokenBitmask(bitmask_dltensor, vocab_size, index);
+}
+
 void Kernels_ApplyTokenBitmaskInplaceCPU(
     intptr_t logits_ptr,
     std::pair<int64_t, int64_t> logits_shape,

cpp/nanobind/python_methods.h

Lines changed: 5 additions & 1 deletion
@@ -35,7 +35,11 @@ bool GrammarMatcher_FillNextTokenBitmask(
     bool debug_print
 );
 
-std::vector<int> Matcher_DebugGetMaskedTokensFromBitmask(
+std::vector<int> Testing_DebugGetMaskedTokensFromBitmask(
+    intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
+);
+
+std::pair<bool, int> Testing_IsSingleTokenBitmask(
     intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
 );

cpp/support/dynamic_bitset.h

Lines changed: 24 additions & 0 deletions
@@ -134,6 +134,18 @@ class DynamicBitset {
     return *this;
   }
 
+  int FindFirstOne() const { return DoFindOneFrom(0); }
+
+  int FindNextOne(int pos) const {
+    if (pos >= size_ - 1 || size_ == 0) return -1;
+    ++pos;
+    int blk = pos / BITS_PER_BLOCK;
+    int ind = pos % BITS_PER_BLOCK;
+    uint32_t fore = data_[blk] >> ind;
+    int result = fore ? pos + LowestBit(fore) : DoFindOneFrom(blk + 1);
+    return result < size_ ? result : -1;
+  }
+
   int FindFirstZero() const { return DoFindZeroFrom(0); }
 
   int FindNextZero(int pos) const {
@@ -210,6 +222,18 @@ class DynamicBitset {
     return position * BITS_PER_BLOCK + LowestBit(~data_[position]);
   }
 
+  int DoFindOneFrom(int first_block) const {
+    int position = -1;
+    for (int i = first_block; i < buffer_size_; ++i) {
+      if (data_[i] != 0) {
+        position = i;
+        break;
+      }
+    }
+    if (position == -1) return -1;
+    return position * BITS_PER_BLOCK + LowestBit(data_[position]);
+  }
+
   // The size of the bitset.
   int size_;
   // The size of the buffer.
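FindNextOne mirrors the existing FindNextZero: shift the current block right so the bit after pos sits at the bottom, take its lowest set bit if any survive, and otherwise fall through to DoFindOneFrom on the remaining blocks. A small Python model of the same scan, assuming blocks is a list of non-negative 32-bit ints and using (x & -x).bit_length() - 1 in the role of LowestBit:

    BITS_PER_BLOCK = 32

    def lowest_bit(x):
        # Index of the lowest set bit of a nonzero int, i.e. C++ LowestBit(x).
        return (x & -x).bit_length() - 1

    def find_next_one(blocks, size, pos):
        # First set bit strictly after position pos, or -1 if there is none.
        if pos >= size - 1 or size == 0:
            return -1
        pos += 1
        blk, ind = divmod(pos, BITS_PER_BLOCK)
        fore = blocks[blk] >> ind  # remaining bits of the current block
        if fore:
            result = pos + lowest_bit(fore)
        else:
            # cf. DoFindOneFrom(blk + 1): lowest set bit of the next non-zero block.
            result = next(
                (i * BITS_PER_BLOCK + lowest_bit(b)
                 for i, b in enumerate(blocks[blk + 1:], start=blk + 1) if b),
                -1,
            )
        return result if 0 <= result < size else -1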

include/xgrammar/matcher.h

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ void _DebugGetMaskedTokensFromBitmask(
     std::vector<int>* rejected_tokens, const DLTensor& token_bitmask, int vocab_size, int index = 0
 );
 
+std::pair<bool, int> _IsSingleTokenBitmask(const DLTensor& bitmask, int vocab_size, int index);
+
 void ApplyTokenBitmaskInplaceCPU(
     DLTensor* logits,
     const DLTensor& bitmask,

python/xgrammar/testing.py

Lines changed: 26 additions & 0 deletions
@@ -175,6 +175,32 @@ def _get_masked_tokens_from_bitmask(
     )
 
 
+def _is_single_token_bitmask(
+    bitmask: torch.Tensor, vocab_size: int, index: int = 0
+) -> Tuple[bool, int]:
+    """Check whether the bitmask allows exactly one token.
+
+    Parameters
+    ----------
+    bitmask : torch.Tensor
+        The bitmask to check. Should be on CPU.
+    vocab_size : int
+        The size of the vocabulary.
+    index : int, default: 0
+        The index of the bitmask.
+
+    Returns
+    -------
+    is_single_token : bool
+        True if the bitmask allows exactly one token, False otherwise.
+    token_id : int
+        The id of that token if the bitmask allows exactly one token, -1 otherwise.
+    """
+    return _core.testing._is_single_token_bitmask(
+        bitmask.data_ptr(), list(bitmask.shape), vocab_size, index
+    )
+
+
 def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
     """Get the bitmask from bool mask. If the bool mask does not align with the 32-bit block
     size, it will add extra 1 paddings.

tests/python/test_token_bitmask_operations.py

Lines changed: 22 additions & 1 deletion
@@ -8,7 +8,11 @@
 import torch
 
 import xgrammar as xgr
-from xgrammar.testing import _bool_mask_to_bitmask, _get_masked_tokens_from_bitmask
+from xgrammar.testing import (
+    _bool_mask_to_bitmask,
+    _get_masked_tokens_from_bitmask,
+    _is_single_token_bitmask,
+)
 
 _is_cuda_available = torch.cuda.is_available()
 _is_mps_available = torch.backends.mps.is_available()
@@ -38,6 +42,23 @@ def test_get_masked_tokens_from_bitmask(token_mask_size: int, index: int)
     assert _get_masked_tokens_from_bitmask(bitmask, token_mask_size, index) == expected
 
 
+def test_is_single_token_bitmask():
+    batch = 2
+    batch_index = 1
+    vocab_size = 1024
+    token_id = 100
+
+    bool_mask = torch.zeros(batch, vocab_size, dtype=torch.bool)
+    bitmask = _bool_mask_to_bitmask(bool_mask)
+    assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1)
+    bool_mask[batch_index, token_id] = True
+    bitmask = _bool_mask_to_bitmask(bool_mask)
+    assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (True, token_id)
+    bool_mask[batch_index, token_id + 1] = True
+    bitmask = _bool_mask_to_bitmask(bool_mask)
+    assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1)
+
+
 @pytest.mark.parametrize("device", ("cpu", "cuda"))
 def test_apply_token_bitmask_inplace(device: str):
     if device == "cuda" and not _is_cuda_available:
