File tree Expand file tree Collapse file tree 2 files changed +38
-0
lines changed Expand file tree Collapse file tree 2 files changed +38
-0
lines changed Original file line number Diff line number Diff line change @@ -111,6 +111,22 @@ def reset(self) -> None:
111111 self ._last_n_tokens = [0 ] * self ._n_ctx # n_ctx elements
112112 self ._n_past = 0
113113
def tokenize(self, text: str):
    """
    Tokenize ``text`` using this model's context.

    :param text: the string to split into tokens
    :return: the tokens produced by ``pp.llama_tokenize`` for ``text``
    """
    # NOTE(review): the third positional argument is always True; it is
    # presumably the "prepend BOS token" flag -- confirm against the binding.
    tokens = pp.llama_tokenize(self._ctx, text, True)
    return tokens
121+
def detokenize(self, tokens: list):
    """
    Converts a list of tokens back into the text they represent

    :param tokens: list of tokens to convert, passed unchanged to
                   ``pp.llama_tokens_to_str`` together with this model's context
    :return: A string representing the text extracted from the tokens
    """
    return pp.llama_tokens_to_str(self._ctx, tokens)
129+
114130 def generate (self ,
115131 prompt : str ,
116132 n_predict : Union [None , int ] = None ,
Original file line number Diff line number Diff line change @@ -87,7 +87,27 @@ std::vector<llama_token> llama_tokenize_wrapper(
8787// return tokens;
8888// }
8989
90+ std::string llama_tokens_to_str_wrapper (struct llama_context_wrapper * ctx_w, py::array_t <llama_token> tokens_array) {
91+ std::string result;
92+ struct llama_context * ctx = ctx_w->ptr ;
93+ bool all_tokens_valid = true ;
94+
95+ for (int i = 0 ; i < tokens_array.size (); i++) {
96+ llama_token token = tokens_array.at (i);
97+ if (token >= llama_n_vocab (ctx)) {
98+ all_tokens_valid = false ;
99+ break ;
100+ }
101+
102+ result += llama_token_to_str (ctx, token);
103+ }
90104
105+ if (all_tokens_valid) {
106+ return result;
107+ } else {
108+ return " " ;
109+ }
110+ }
91111
92112int llama_n_vocab_wrapper (struct llama_context_wrapper * ctx_w){
93113 struct llama_context * ctx = ctx_w->ptr ;
@@ -697,6 +717,8 @@ PYBIND11_MODULE(_pyllamacpp, m) {
697717 // @NOTE: to prevent implicit conversion of const char* to unicode on python side, leading to UnicodeDecodeError
698718 return py::bytes (llama_token_to_str_wrapper (ctx_w, token));
699719 });
720+ m.def (" llama_tokens_to_str" , &llama_tokens_to_str_wrapper);
721+
700722
701723 m.def (" llama_token_bos" , &llama_token_bos);
702724 m.def (" llama_token_eos" , &llama_token_eos);
You can’t perform that action at this time.
0 commit comments