diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index fa937f13e..d23ec42c6 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -14,7 +14,7 @@ serde = { version = "1.0", features = ["rc", "derive"] }
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.11"
-pyo3 = { version = "0.24.2", features = ["abi3", "abi3-py39", "py-clone"] }
+pyo3 = { version = "0.24.2", features = ["py-clone"] }
 numpy = "0.24"
 ndarray = "0.16"
 itertools = "0.12"
diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 11e6e556c..fee94e3bd 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -97,6 +97,20 @@ class Encoding:
         """
         pass
 
+    @property
+    def attention_mask_buffer(self):
+        """
+        The attention mask as a buffer.
+
+        This indicates to the LM which tokens should be attended to, and which should not.
+        This is especially important when batching sequences, where we need to apply
+        padding.
+
+        Returns
+            :obj:`Buffer`: The attention mask
+        """
+        pass
+
     def char_to_token(self, char_pos, sequence_index=0):
         """
         Get the token that contains the char at the given position in the input sequence.
@@ -140,6 +154,19 @@ class Encoding:
         """
         pass
 
+    @property
+    def ids_buffer(self):
+        """
+        The generated IDs as a buffer.
+
+        The IDs are the main input to a Language Model. They are the token indices,
+        the numerical representations that a LM understands.
+
+        Returns
+            :obj:`Buffer`: The buffer of IDs
+        """
+        pass
+
     @staticmethod
     def merge(encodings, growing_offsets=True):
         """
@@ -180,6 +207,19 @@ class Encoding:
         """
         pass
 
+    @property
+    def offsets_buffer(self):
+        """
+        The offsets as a buffer.
+
+        These offsets let you slice the input string, and thus retrieve the original
+        part of the text that led to producing each token.
+
+        Returns
+            :obj:`Buffer`: The buffer of offsets
+        """
+        pass
+
     @property
     def overflowing(self):
         """
@@ -252,6 +292,18 @@ class Encoding:
         """
         pass
 
+    @property
+    def special_tokens_mask_buffer(self):
+        """
+        The special token mask as a buffer.
+
+        This indicates which tokens are special tokens, and which are not.
+
+        Returns
+            :obj:`Buffer`: The special tokens mask
+        """
+        pass
+
     def token_to_chars(self, token_index):
         """
         Get the offsets of the token at the given index.
@@ -346,6 +398,19 @@ class Encoding:
         """
         pass
 
+    @property
+    def type_ids_buffer(self):
+        """
+        The generated type IDs as a buffer.
+
+        Generally used for tasks like sequence classification or question answering,
+        these tokens let the LM know which input sequence corresponds to each token.
+
+        Returns
+            :obj:`Buffer`: The buffer of type IDs
+        """
+        pass
+
     @property
     def word_ids(self):
         """
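The `*_buffer` properties above expose the Python buffer protocol instead of building Python lists, so consumers can view the Rust-owned data without a per-element copy. A minimal sketch of the intended usage (the tokenizer setup mirrors the test added at the end of this diff; numpy is just one possible consumer):

import numpy as np
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john"])
encoding = tokenizer.encode("my name is john")

# Zero-copy view over the Rust-owned ids (u32 maps to numpy's uint32).
ids = np.frombuffer(encoding.ids_buffer, dtype=np.uint32)
# Stdlib-only alternative, no numpy required.
mask = memoryview(encoding.attention_mask_buffer)
assert ids.tolist() == encoding.ids
assert mask.tolist() == encoding.attention_mask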
diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs
index e157b8006..2d9af12af 100644
--- a/bindings/python/src/encoding.rs
+++ b/bindings/python/src/encoding.rs
@@ -119,6 +119,21 @@ impl PyEncoding {
         self.encoding.get_ids().to_vec()
     }
 
+    /// The generated IDs as a buffer.
+    ///
+    /// The IDs are the main input to a Language Model. They are the token indices,
+    /// the numerical representations that a LM understands.
+    ///
+    /// Returns
+    ///     :obj:`Buffer`: The buffer of IDs
+    #[getter]
+    fn get_ids_buffer<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyUInt32Buffer>> {
+        let data = self.encoding.get_ids().to_vec();
+        let shape = vec![data.len() as isize];
+        let buffer = PyUInt32Buffer::new(data, shape);
+        buffer.into_pyobject(py)
+    }
+
     /// The generated tokens
     ///
     /// They are the string representation of the IDs.
@@ -198,6 +213,21 @@ impl PyEncoding {
         self.encoding.get_type_ids().to_vec()
     }
 
+    /// The generated type IDs as a buffer.
+    ///
+    /// Generally used for tasks like sequence classification or question answering,
+    /// these tokens let the LM know which input sequence corresponds to each token.
+    ///
+    /// Returns
+    ///     :obj:`Buffer`: The buffer of type IDs
+    #[getter]
+    fn get_type_ids_buffer<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyUInt32Buffer>> {
+        let data = self.encoding.get_type_ids().to_vec();
+        let shape = vec![data.len() as isize];
+        let buffer = PyUInt32Buffer::new(data, shape);
+        buffer.into_pyobject(py)
+    }
+
     /// The offsets associated to each token
     ///
     /// These offsets let's you slice the input string, and thus retrieve the original
@@ -210,6 +240,25 @@ impl PyEncoding {
         self.encoding.get_offsets().to_vec()
     }
 
+    /// The offsets associated to each token, as a buffer.
+    ///
+    /// These offsets let you slice the input string, and thus retrieve the original
+    /// part of the text that led to producing each token.
+    ///
+    /// Returns
+    ///     :obj:`Buffer`: The buffer of offsets
+    #[getter]
+    fn get_offsets_buffer<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyUSizeBuffer>> {
+        let data = self.encoding.get_offsets().to_vec();
+        let shape = vec![data.len() as isize, 2];
+        let data: Vec<usize> = data
+            .into_iter()
+            .flat_map(|(start, end)| [start, end])
+            .collect();
+        let buffer = PyUSizeBuffer::new(data, shape);
+        buffer.into_pyobject(py)
+    }
+
     /// The special token mask
     ///
     /// This indicates which tokens are special tokens, and which are not.
@@ -221,6 +270,23 @@ impl PyEncoding {
         self.encoding.get_special_tokens_mask().to_vec()
     }
 
+    /// The special token mask as a buffer.
+    ///
+    /// This indicates which tokens are special tokens, and which are not.
+    ///
+    /// Returns
+    ///     :obj:`Buffer`: The special tokens mask
+    #[getter]
+    fn get_special_tokens_mask_buffer<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> PyResult<Bound<'py, PyUInt32Buffer>> {
+        let data = self.encoding.get_special_tokens_mask().to_vec();
+        let shape = vec![data.len() as isize];
+        let buffer = PyUInt32Buffer::new(data, shape);
+        buffer.into_pyobject(py)
+    }
+
     /// The attention mask
     ///
     /// This indicates to the LM which tokens should be attended to, and which should not.
@@ -234,6 +300,25 @@ impl PyEncoding {
         self.encoding.get_attention_mask().to_vec()
     }
 
+    /// The attention mask as a buffer.
+    ///
+    /// This indicates to the LM which tokens should be attended to, and which should not.
+    /// This is especially important when batching sequences, where we need to apply
+    /// padding.
+    ///
+    /// Returns
+    ///     :obj:`Buffer`: The attention mask
+    #[getter]
+    fn get_attention_mask_buffer<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> PyResult<Bound<'py, PyUInt32Buffer>> {
+        let data = self.encoding.get_attention_mask().to_vec();
+        let shape = vec![data.len() as isize];
+        let buffer = PyUInt32Buffer::new(data, shape);
+        buffer.into_pyobject(py)
+    }
+
     /// A :obj:`List` of overflowing :class:`~tokenizers.Encoding`
     ///
     /// When using truncation, the :class:`~tokenizers.Tokenizer` takes care of splitting
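`get_offsets_buffer` is the one getter that changes the data layout: the `Vec<(usize, usize)>` is flattened into a single `usize` buffer that advertises shape `[len, 2]`, one row per token. A sketch of how that round-trips on the Python side (`np.uintp` is the platform-sized unsigned integer matching `usize`, i.e. format character `"N"`):

import numpy as np
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john"])
encoding = tokenizer.encode("my name is john")

# frombuffer always yields a flat array; reshape recovers the [n, 2] rows.
offsets = np.frombuffer(encoding.offsets_buffer, dtype=np.uintp).reshape(-1, 2)
assert [tuple(row) for row in offsets.tolist()] == encoding.offsets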
@@ -457,3 +542,131 @@ impl PyEncoding {
         Ok(())
     }
 }
+
+macro_rules! define_py_buffer_protocol_type {
+    ($struct_name:ident, $data_type:ty) => {
+        #[pyclass]
+        struct $struct_name {
+            data: Vec<$data_type>,
+            shape: Vec<isize>,
+            strides: Vec<isize>,
+        }
+
+        impl $struct_name {
+            fn new(data: Vec<$data_type>, shape: Vec<isize>) -> Self {
+                // Row-major (C-contiguous) byte strides, computed right-to-left.
+                let mut strides: Vec<isize> = Vec::with_capacity(shape.len());
+                let mut stride = std::mem::size_of::<$data_type>() as isize;
+                for dim in shape.iter().rev() {
+                    strides.push(stride);
+                    stride *= dim;
+                }
+                strides.reverse();
+
+                Self {
+                    data,
+                    shape,
+                    strides,
+                }
+            }
+        }
+
+        #[pymethods]
+        impl $struct_name {
+            // Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
+            unsafe fn __getbuffer__(
+                slf: pyo3::prelude::Bound<'_, Self>,
+                view: *mut pyo3::ffi::Py_buffer,
+                flags: std::os::raw::c_int,
+            ) -> pyo3::prelude::PyResult<()> {
+                if view.is_null() {
+                    return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
+                }
+                if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
+                    return Err(pyo3::exceptions::PyBufferError::new_err(
+                        "Object is not writable",
+                    ));
+                }
+
+                let borrow = slf.borrow();
+                let data = &borrow.data;
+                let shape = &borrow.shape;
+                let strides = &borrow.strides;
+
+                (*view).obj = slf.clone().into_any().into_ptr();
+                (*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
+                (*view).len = (data.len() * std::mem::size_of::<$data_type>()) as isize;
+                (*view).readonly = 1;
+                (*view).itemsize = std::mem::size_of::<$data_type>() as isize;
+                (*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
+                    // Map the Rust element type to the struct-module format character.
+                    let data_type_id = std::any::TypeId::of::<$data_type>();
+                    let format = {
+                        if data_type_id == std::any::TypeId::of::<bool>() {
+                            "?"
+                        } else if data_type_id == std::any::TypeId::of::<i8>() {
+                            "b"
+                        } else if data_type_id == std::any::TypeId::of::<u8>() {
+                            "B"
+                        } else if data_type_id == std::any::TypeId::of::<i16>() {
+                            "h"
+                        } else if data_type_id == std::any::TypeId::of::<u16>() {
+                            "H"
+                        } else if data_type_id == std::any::TypeId::of::<i32>() {
+                            "i"
+                        } else if data_type_id == std::any::TypeId::of::<u32>() {
+                            "I"
+                        } else if data_type_id == std::any::TypeId::of::<i64>() {
+                            "q"
+                        } else if data_type_id == std::any::TypeId::of::<u64>() {
+                            "Q"
+                        } else if data_type_id == std::any::TypeId::of::<isize>() {
+                            "n"
+                        } else if data_type_id == std::any::TypeId::of::<usize>() {
+                            "N"
+                        } else if data_type_id == std::any::TypeId::of::<f32>() {
+                            "f"
+                        } else if data_type_id == std::any::TypeId::of::<f64>() {
+                            "d"
+                        } else {
+                            return Err(pyo3::exceptions::PyBufferError::new_err(
+                                "Unsupported data type",
+                            ));
+                        }
+                    };
+                    let msg = std::ffi::CString::new(format).unwrap();
+                    msg.into_raw()
+                } else {
+                    std::ptr::null_mut()
+                };
+                (*view).ndim = shape.len() as i32;
+                (*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
+                    shape.as_ptr() as *mut _
+                } else {
+                    std::ptr::null_mut()
+                };
+                (*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES
+                {
+                    strides.as_ptr() as *mut _
+                } else {
+                    std::ptr::null_mut()
+                };
+                (*view).suboffsets = std::ptr::null_mut();
+                (*view).internal = std::ptr::null_mut();
+
+                Ok(())
+            }
+
+            unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
+                // Reclaim the format string allocated in __getbuffer__; it is null
+                // when PyBUF_FORMAT was not requested, so guard before from_raw.
+                if !(*view).format.is_null() {
+                    std::mem::drop(std::ffi::CString::from_raw((*view).format));
+                }
+            }
+        }
+    };
+}
+
+define_py_buffer_protocol_type!(PyUInt32Buffer, u32);
+define_py_buffer_protocol_type!(PyUSizeBuffer, usize);
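The strides computed in `new` are the standard C-contiguous (row-major) byte strides: walking the shape right to left, each axis's stride is the element size times the product of all later dimensions. A Python mirror of that loop, for checking the arithmetic:

def c_strides(shape, itemsize):
    # Walk the shape right-to-left: the last axis moves by one element,
    # each earlier axis by the product of all later dimensions.
    strides, stride = [], itemsize
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    return list(reversed(strides))

# Offsets buffer: shape [n, 2] with 8-byte usize items (64-bit platform)
# -> rows are 16 bytes apart, columns 8 bytes apart.
assert c_strides([5, 2], 8) == [16, 8]
# 1-D u32 buffers (ids, masks): the only stride is the item size.
assert c_strides([7], 4) == [4]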
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index d50f283e7..19f21190c 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -592,6 +592,16 @@ def test_setting_to_none(self):
         tokenizer.pre_tokenizer = None
         assert tokenizer.pre_tokenizer == None
 
+    def test_encode_buffer_protocol(self):
+        tokenizer = Tokenizer(BPE())
+        tokenizer.add_tokens(["my", "name", "is", "john"])
+        output = tokenizer.encode("my name is john")
+        assert output.ids == memoryview(output.ids_buffer).tolist()
+        assert output.type_ids == memoryview(output.type_ids_buffer).tolist()
+        assert output.attention_mask == memoryview(output.attention_mask_buffer).tolist()
+        assert output.offsets == [tuple(offset) for offset in memoryview(output.offsets_buffer).tolist()]
+        assert output.special_tokens_mask == memoryview(output.special_tokens_mask_buffer).tolist()
+
 
 class TestTokenizerRepr:
     def test_repr(self):
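One detail the new test relies on: because the offsets buffer advertises two dimensions, `memoryview(...).tolist()` returns one list per row, while `Encoding.offsets` holds tuples, hence the `tuple(offset)` conversion. A short sketch of that behavior, reusing the tokenizer built in the earlier sketches:

output = tokenizer.encode("my name is john")  # same setup as the test above
mv = memoryview(output.offsets_buffer)
assert mv.ndim == 2 and mv.shape[1] == 2  # one (start, end) row per token
assert mv.readonly  # __getbuffer__ rejects writable buffer requests
assert [tuple(row) for row in mv.tolist()] == output.offsets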