Commit 8814efc
Update openai_gen.py to explain need of prioritization of tokens
Parent: 75a0f8b

1 file changed: +8 −2 lines changed

src/monitors4codegen/monitor_guided_decoding/openai_gen.py

Lines changed: 8 additions & 2 deletions
@@ -39,7 +39,13 @@ def openai_mgd(
     tokens_sort_key = {k:[0, 0] for k in tokenizer.all_token_ids}

     # # TODO: Find a way to prioritize tokens to be blacklisted
-    # # 1. The following code uses info about whether has a break char in it
+
+    # # Why prioritize? OpenAI allows applying logit_bias to upto 300 tokens, whereas the typical number of tokens in vocabulary is 50,000.
+    # # Because of this, it is necessary to identify the top 300 tokens, that we think need to be either blacklisted, or whitelisted.
+    # # This prioritization should be done taking into account what violating token is the model likely to predict in the next step.
+
+    # # Options for prioritization of tokens:
+    # # 1. The following code uses info about whether the token has a break char in it
     # for token, token_id in tokenizer.vocab_trie.iteritems():
     #     if token[0] in monitor.all_break_chars:
     #         tokens_sort_key[token_id][0] = 0 # ".", ", a"
@@ -164,4 +170,4 @@ def convert_bytesrep_to_bytes(x: str) -> bytes:
     gen_tokens += new_all_tokens[all_tokens.shape[0]:].tolist()
     all_tokens = new_all_tokens

-    return gen_tokens, gen_text.decode()
+    return gen_tokens, gen_text.decode()
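The comments added in this commit describe the motivation but not an implementation: since the OpenAI API accepts `logit_bias` for only about 300 token ids while vocabularies hold on the order of 50,000 tokens, the violating tokens must be ranked by how likely the model is to predict them next, and only the top 300 banned. A minimal sketch of that selection step is below; the function and parameter names (`select_tokens_to_bias`, `next_token_scores`) are hypothetical, not part of the repository's code.

```python
# Hypothetical sketch of the prioritization described in the diff's comments:
# rank violating token ids by the model's next-step likelihood and blacklist
# only the top-K via a strong negative logit bias.
from typing import Dict, List, Set

MAX_LOGIT_BIAS_TOKENS = 300  # API limit mentioned in the added comments


def select_tokens_to_bias(
    violating_token_ids: Set[int],
    next_token_scores: Dict[int, float],
    max_tokens: int = MAX_LOGIT_BIAS_TOKENS,
) -> Dict[int, float]:
    """Pick the violating tokens the model is most likely to emit next and
    assign them a strong negative bias (-100 effectively bans a token)."""
    # Sort violating tokens by predicted score, highest first; tokens with
    # no score entry are treated as least likely.
    ranked: List[int] = sorted(
        violating_token_ids,
        key=lambda tid: next_token_scores.get(tid, float("-inf")),
        reverse=True,
    )
    # Keep only the top `max_tokens` ids, since the API rejects larger maps.
    return {tid: -100.0 for tid in ranked[:max_tokens]}
```

The resulting dict could then be passed as the `logit_bias` argument of a completion request; whitelisting would work the same way with a positive bias instead.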
