Expose Encoding attributes via the buffer protocol interface #1789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Status: Open · wants to merge 1 commit into `main`
2 changes: 1 addition & 1 deletion bindings/python/Cargo.toml
@@ -14,7 +14,7 @@ serde = { version = "1.0", features = ["rc", "derive"] }
serde_json = "1.0"
libc = "0.2"
env_logger = "0.11"
pyo3 = { version = "0.24.2", features = ["abi3", "abi3-py39", "py-clone"] }
pyo3 = { version = "0.24.2", features = ["py-clone"] }
numpy = "0.24"
ndarray = "0.16"
itertools = "0.12"
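Note: dropping the `abi3`/`abi3-py39` features is presumably what makes this change possible: PyO3 does not support the buffer protocol (`__getbuffer__`/`__releasebuffer__`) under stable-ABI `abi3` builds, since `Py_buffer` is not part of the limited API those builds target. The trade-off is that wheels would again need to be built per Python version.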
65 changes: 65 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -97,6 +97,20 @@ class Encoding:
"""
pass

@property
def attention_mask_buffer(self):
"""
The attention mask as a buffer.

This indicates to the LM which tokens should be attended to, and which should not.
This is especially important when batching sequences, where we need to apply
padding.

Returns
:obj:`Buffer`: The attention mask
"""
pass

def char_to_token(self, char_pos, sequence_index=0):
"""
Get the token that contains the char at the given position in the input sequence.
@@ -140,6 +154,19 @@ class Encoding:
"""
pass

@property
def ids_buffer(self):
"""
The generated IDs as a buffer.

The IDs are the main input to a Language Model. They are the token indices,
the numerical representations that a LM understands.

Returns
:obj:`Buffer`: The buffer of IDs
"""
pass

@staticmethod
def merge(encodings, growing_offsets=True):
"""
@@ -180,6 +207,19 @@ class Encoding:
"""
pass

@property
def offsets_buffer(self):
"""
The offsets associated to each token, as a buffer.

These offsets let you slice the input string, and thus retrieve the original
part that led to producing the corresponding token. The buffer is
two-dimensional, with one ``(start, end)`` pair per token.

Returns
:obj:`Buffer`: The buffer of offsets
"""
pass

@property
def overflowing(self):
"""
@@ -252,6 +292,18 @@ class Encoding:
"""
pass

@property
def special_tokens_mask_buffer(self):
"""
The special token mask as a buffer.

This indicates which tokens are special tokens, and which are not.

Returns
:obj:`Buffer`: The special tokens mask
"""
pass

def token_to_chars(self, token_index):
"""
Get the offsets of the token at the given index.
@@ -346,6 +398,19 @@ class Encoding:
"""
pass

@property
def type_ids_buffer(self):
"""
The generated type IDs as a buffer.

Generally used for tasks like sequence classification or question answering,
these tokens let the LM know which input sequence corresponds to each token.

Returns
:obj:`Buffer`: The buffer of type IDs
"""
pass

@property
def word_ids(self):
"""
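For reference, a minimal usage sketch of the new `*_buffer` properties (mirroring the test added at the bottom of this diff; both `memoryview` and `numpy.frombuffer` consume the buffer protocol without copying):

```python
import numpy as np
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john"])
enc = tokenizer.encode("my name is john")

# Zero-copy view over the Rust-owned token IDs.
ids = memoryview(enc.ids_buffer)
assert ids.tolist() == enc.ids

# numpy can wrap the 1-D u32 buffers directly (still no copy).
attention = np.frombuffer(enc.attention_mask_buffer, dtype=np.uint32)

# offsets_buffer is two-dimensional: one (start, end) pair per token.
offsets = memoryview(enc.offsets_buffer)
assert offsets.ndim == 2
```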
207 changes: 207 additions & 0 deletions bindings/python/src/encoding.rs
@@ -119,6 +119,21 @@ impl PyEncoding {
self.encoding.get_ids().to_vec()
}

/// The generated IDs as a buffer.
///
/// The IDs are the main input to a Language Model. They are the token indices,
/// the numerical representations that a LM understands.
///
/// Returns
/// :obj:`Buffer`: The buffer of IDs
#[getter]
fn get_ids_buffer<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyUInt32Buffer>> {
let data = self.encoding.get_ids().to_vec();
let shape = vec![data.len() as isize];
let buffer = PyUInt32Buffer::new(data, shape);
buffer.into_pyobject(py)
}

/// The generated tokens
///
/// They are the string representation of the IDs.
@@ -198,6 +213,21 @@ impl PyEncoding {
self.encoding.get_type_ids().to_vec()
}

/// The generated type IDs as a buffer.
///
/// Generally used for tasks like sequence classification or question answering,
/// these tokens let the LM know which input sequence corresponds to each token.
///
/// Returns
/// :obj:`Buffer`: The buffer of type IDs
#[getter]
fn get_type_ids_buffer<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyUInt32Buffer>> {
let data = self.encoding.get_type_ids().to_vec();
let shape = vec![data.len() as isize];
let buffer = PyUInt32Buffer::new(data, shape);
buffer.into_pyobject(py)
}

/// The offsets associated to each token
///
/// These offsets let you slice the input string, and thus retrieve the original
@@ -210,6 +240,25 @@
self.encoding.get_offsets().to_vec()
}

/// The offsets associated to each token as a buffer.
///
/// These offsets let you slice the input string, and thus retrieve the original
/// part that led to producing the corresponding token. The buffer is
/// two-dimensional, with one `(start, end)` pair per token.
///
/// Returns
/// :obj:`Buffer`: The buffer of offsets
#[getter]
fn get_offsets_buffer<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyUSizeBuffer>> {
let data = self.encoding.get_offsets().to_vec();
let shape = vec![data.len() as isize, 2];
let data: Vec<usize> = data
.into_iter()
.flat_map(|(start, end)| [start, end])
.collect();
let buffer = PyUSizeBuffer::new(data, shape);
buffer.into_pyobject(py)
}

/// The special token mask
///
/// This indicates which tokens are special tokens, and which are not.
@@ -221,6 +270,23 @@
self.encoding.get_special_tokens_mask().to_vec()
}

/// The special token mask as a buffer.
///
/// This indicates which tokens are special tokens, and which are not.
///
/// Returns
/// :obj:`Buffer`: The special tokens mask
#[getter]
fn get_special_tokens_mask_buffer<'py>(
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyUInt32Buffer>> {
let data = self.encoding.get_special_tokens_mask().to_vec();
let shape = vec![data.len() as isize];
let buffer = PyUInt32Buffer::new(data, shape);
buffer.into_pyobject(py)
}

/// The attention mask
///
/// This indicates to the LM which tokens should be attended to, and which should not.
@@ -234,6 +300,25 @@
self.encoding.get_attention_mask().to_vec()
}

/// The attention mask as a buffer.
///
/// This indicates to the LM which tokens should be attended to, and which should not.
/// This is especially important when batching sequences, where we need to apply
/// padding.
///
/// Returns
/// :obj:`Buffer`: The attention mask
#[getter]
fn get_attention_mask_buffer<'py>(
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyUInt32Buffer>> {
let data = self.encoding.get_attention_mask().to_vec();
let shape = vec![data.len() as isize];
let buffer = PyUInt32Buffer::new(data, shape);
buffer.into_pyobject(py)
}

/// A :obj:`List` of overflowing :class:`~tokenizers.Encoding`
///
/// When using truncation, the :class:`~tokenizers.Tokenizer` takes care of splitting
@@ -457,3 +542,125 @@ impl PyEncoding {
Ok(())
}
}

macro_rules! define_py_buffer_protocol_type {
($struct_name:ident, $data_type:ty) => {
#[pyclass]
struct $struct_name {
data: Vec<$data_type>,
shape: Vec<isize>,
strides: Vec<isize>,
}

impl $struct_name {
fn new(data: Vec<$data_type>, shape: Vec<isize>) -> Self {
let mut strides: Vec<isize> = Vec::with_capacity(shape.len());
let mut stride = std::mem::size_of::<$data_type>() as isize;
for dim in shape.iter().rev() {
strides.push(stride);
stride *= dim;
}
strides.reverse();

Self {
data,
shape,
strides,
}
}
}

#[pymethods]
impl $struct_name {
// Based on https://github.com/PyO3/pyo3/blob/v0.22.2/tests/test_buffer_protocol.rs#L25
unsafe fn __getbuffer__(
slf: pyo3::prelude::Bound<'_, Self>,
view: *mut pyo3::ffi::Py_buffer,
flags: std::os::raw::c_int,
) -> pyo3::prelude::PyResult<()> {
if view.is_null() {
return Err(pyo3::exceptions::PyBufferError::new_err("View is null"));
}
if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
return Err(pyo3::exceptions::PyBufferError::new_err(
"Object is not writable",
));
}

let borrow = slf.borrow();
let data = &borrow.data;
let shape = &borrow.shape;
let strides = &borrow.strides;

(*view).obj = slf.clone().into_any().into_ptr();
(*view).buf = data.as_ptr() as *mut std::os::raw::c_void;
(*view).len = (data.len() * std::mem::size_of::<$data_type>()) as isize;
(*view).readonly = 1;
(*view).itemsize = std::mem::size_of::<$data_type>() as isize;
(*view).format = if (flags & pyo3::ffi::PyBUF_FORMAT) == pyo3::ffi::PyBUF_FORMAT {
let data_type_id = std::any::TypeId::of::<$data_type>();
let format = {
if data_type_id == std::any::TypeId::of::<bool>() {
"?"
} else if data_type_id == std::any::TypeId::of::<i8>() {
"b"
} else if data_type_id == std::any::TypeId::of::<u8>() {
"B"
} else if data_type_id == std::any::TypeId::of::<i16>() {
"h"
} else if data_type_id == std::any::TypeId::of::<u16>() {
"H"
} else if data_type_id == std::any::TypeId::of::<i32>() {
"i"
} else if data_type_id == std::any::TypeId::of::<u32>() {
"I"
} else if data_type_id == std::any::TypeId::of::<i64>() {
"q"
} else if data_type_id == std::any::TypeId::of::<u64>() {
"Q"
} else if data_type_id == std::any::TypeId::of::<isize>() {
"n"
} else if data_type_id == std::any::TypeId::of::<usize>() {
"N"
} else if data_type_id == std::any::TypeId::of::<f32>() {
"f"
} else if data_type_id == std::any::TypeId::of::<f64>() {
"d"
} else {
return Err(pyo3::exceptions::PyBufferError::new_err(
"Unsupported data type",
));
}
};
// Hand ownership of the format string to the view as a raw pointer;
// it is reclaimed and freed in __releasebuffer__.
let c_format = std::ffi::CString::new(format).unwrap();
c_format.into_raw()
} else {
std::ptr::null_mut()
};
(*view).ndim = shape.len() as i32;
(*view).shape = if (flags & pyo3::ffi::PyBUF_ND) == pyo3::ffi::PyBUF_ND {
shape.as_ptr() as *mut _
} else {
std::ptr::null_mut()
};
(*view).strides = if (flags & pyo3::ffi::PyBUF_STRIDES) == pyo3::ffi::PyBUF_STRIDES
{
strides.as_ptr() as *mut _
} else {
std::ptr::null_mut()
};
(*view).suboffsets = std::ptr::null_mut();
(*view).internal = std::ptr::null_mut();

Ok(())
}

unsafe fn __releasebuffer__(&self, view: *mut pyo3::ffi::Py_buffer) {
// Reclaim the format string allocated in __getbuffer__ so it is freed.
// It is null when PyBUF_FORMAT was not requested, so guard against that.
if !(*view).format.is_null() {
std::mem::drop(std::ffi::CString::from_raw((*view).format));
}
}
}
};
}

define_py_buffer_protocol_type!(PyUInt32Buffer, u32);
define_py_buffer_protocol_type!(PyUSizeBuffer, usize);
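For intuition, the stride computation in the macro yields standard C-contiguous strides in bytes. A small Python sketch (hypothetical helper, not part of the PR) reproducing the same loop:

```python
def c_contiguous_strides(shape, itemsize):
    # Walk dimensions from innermost to outermost, accumulating the byte
    # step for each dimension, then restore outermost-first order.
    strides = []
    stride = itemsize
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    return list(reversed(strides))

# offsets_buffer with 4 tokens: shape (4, 2) of 8-byte usize values.
assert c_contiguous_strides([4, 2], 8) == [16, 8]
# ids_buffer with 5 tokens: shape (5,) of 4-byte u32 values.
assert c_contiguous_strides([5], 4) == [4]
```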
10 changes: 10 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -592,6 +592,16 @@ def test_setting_to_none(self):
tokenizer.pre_tokenizer = None
assert tokenizer.pre_tokenizer == None

def test_encode_buffer_protocol(self):
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john"])
output = tokenizer.encode("my name is john")
assert output.ids == memoryview(output.ids_buffer).tolist()
assert output.type_ids == memoryview(output.type_ids_buffer).tolist()
assert output.attention_mask == memoryview(output.attention_mask_buffer).tolist()
assert output.offsets == [tuple(offset) for offset in memoryview(output.offsets_buffer).tolist()]
assert output.special_tokens_mask == memoryview(output.special_tokens_mask_buffer).tolist()


class TestTokenizerRepr:
def test_repr(self):
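The test could also assert the buffer metadata filled in by `__getbuffer__`; a sketch of such checks (not part of this PR, reusing the test's `output` and based on the formats and flags the Rust side sets):

```python
mv = memoryview(output.ids_buffer)
assert mv.readonly               # buffers are exposed read-only
assert mv.format == "I"          # u32 maps to the "I" struct format code
assert mv.ndim == 1

offsets_mv = memoryview(output.offsets_buffer)
assert offsets_mv.ndim == 2
assert offsets_mv.shape[1] == 2  # one (start, end) pair per token
```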