Skip to content

Commit dc70fd6

Browse files
authored
Merge pull request #6 from ParisNeo/main
Added tokenize and detokenize functions
2 parents fd33930 + ccf3fc9 commit dc70fd6

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pyllamacpp/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,22 @@ def reset(self) -> None:
111111
self._last_n_tokens = [0] * self._n_ctx # n_ctx elements
112112
self._n_past = 0
113113

114+
def tokenize(self, text:str):
    """
    Converts a text string into the model's token ids.

    :param text: text to be tokenized
    :return: List of tokens
    """
    # Third argument enables prepending the BOS token, matching llama.cpp's
    # expectation for prompt tokenization.
    token_ids = pp.llama_tokenize(self._ctx, text, True)
    return token_ids
121+
122+
def detokenize(self, tokens:list):
    """
    Converts a list of token ids back into the text they encode.

    :param tokens: list of token ids to convert
    :return: A string representing the text extracted from the tokens
    """
    # NOTE(review): the original docstring was copy-pasted from tokenize()
    # and wrongly claimed this returns a list of tokens; it returns a string.
    return pp.llama_tokens_to_str(self._ctx, tokens)
129+
114130
def generate(self,
115131
prompt: str,
116132
n_predict: Union[None, int] = None,

src/main.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,27 @@ std::vector<llama_token> llama_tokenize_wrapper(
8787
// return tokens;
8888
//}
8989

90+
std::string llama_tokens_to_str_wrapper(struct llama_context_wrapper* ctx_w, py::array_t<llama_token> tokens_array) {
91+
std::string result;
92+
struct llama_context * ctx = ctx_w->ptr;
93+
bool all_tokens_valid = true;
94+
95+
for (int i = 0; i < tokens_array.size(); i++) {
96+
llama_token token = tokens_array.at(i);
97+
if (token >= llama_n_vocab(ctx)) {
98+
all_tokens_valid = false;
99+
break;
100+
}
101+
102+
result += llama_token_to_str(ctx, token);
103+
}
90104

105+
if (all_tokens_valid) {
106+
return result;
107+
} else {
108+
return "";
109+
}
110+
}
91111

92112
int llama_n_vocab_wrapper(struct llama_context_wrapper * ctx_w){
93113
struct llama_context * ctx = ctx_w->ptr;
@@ -697,6 +717,8 @@ PYBIND11_MODULE(_pyllamacpp, m) {
697717
//@NOTE: to prevent implicit conversion of const char* to unicode on python side, leading to UnicodeDecodeError
698718
return py::bytes(llama_token_to_str_wrapper(ctx_w, token));
699719
});
720+
m.def("llama_tokens_to_str", &llama_tokens_to_str_wrapper);
721+
700722

701723
m.def("llama_token_bos", &llama_token_bos);
702724
m.def("llama_token_eos", &llama_token_eos);

0 commit comments

Comments
 (0)