diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
index 48277f0d2..b10784728 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
@@ -11,5 +11,6 @@
 Sequence = pre_tokenizers.Sequence
 Split = pre_tokenizers.Split
 UnicodeScripts = pre_tokenizers.UnicodeScripts
+Truncate = pre_tokenizers.Truncate
 Whitespace = pre_tokenizers.Whitespace
 WhitespaceSplit = pre_tokenizers.WhitespaceSplit
diff --git a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
index 6f31ff3a2..83a7b77fc 100644
--- a/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
@@ -521,6 +521,18 @@ class UnicodeScripts(PreTokenizer):
     """
     pass
 
+class Truncate(PreTokenizer):
+    """Truncate pre-tokenizer"""
+
+    def __init__(self, max_length: int = 512, stride: int = 0, direction: str = "right"):
+        pass
+
+    def pre_tokenize(self, pretok):
+        pass
+
+    def pre_tokenize_str(self, sequence):
+        pass
+
 class Whitespace(PreTokenizer):
     """
     This pre-tokenizer simply splits using the following regex: `\w+|[^\w\s]+`
diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index 8140ade1d..51aca4ca0 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -17,7 +17,9 @@ use tk::pre_tokenizers::punctuation::Punctuation;
 use tk::pre_tokenizers::split::Split;
 use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
 use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
+use tk::pre_tokenizers::truncate::Truncate;
 use tk::pre_tokenizers::PreTokenizerWrapper;
+use tk::utils::truncation::{TruncationDirection, TruncationParams};
 use tk::tokenizer::Offsets;
 use tk::{PreTokenizedString, PreTokenizer};
 use tokenizers as tk;
@@ -118,6 +120,12 @@ impl PyPreTokenizer {
                     .into_any()
                     .into()
             }
+            PreTokenizerWrapper::Truncate(_) => {
+                Py::new(py, (PyTruncate {}, base))?
+                    .into_pyobject(py)?
+                    .into_any()
+                    .into()
+            }
         },
     }
 }
@@ -750,6 +758,78 @@ impl PyUnicodeScripts {
     }
 }
+/// Truncate pre-tokenizer
+///
+/// This pre-tokenizer truncates text based on the provided parameters before tokenization.
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Truncate")]
+pub struct PyTruncate {}
+
+#[pymethods]
+impl PyTruncate {
+    #[getter]
+    fn get_max_length(self_: PyRef<Self>) -> usize {
+        getter!(self_, Truncate, params.max_length)
+    }
+
+    #[setter]
+    fn set_max_length(self_: PyRef<Self>, value: usize) {
+        setter!(self_, Truncate, params.max_length, value);
+    }
+
+    #[getter]
+    fn get_stride(self_: PyRef<Self>) -> usize {
+        getter!(self_, Truncate, params.stride)
+    }
+
+    #[setter]
+    fn set_stride(self_: PyRef<Self>, value: usize) {
+        setter!(self_, Truncate, params.stride, value);
+    }
+
+    #[getter]
+    fn get_direction(self_: PyRef<Self>) -> String {
+        getter!(self_, Truncate, params.direction.as_ref()).to_string()
+    }
+
+    #[setter]
+    fn set_direction(self_: PyRef<Self>, direction: &str) -> PyResult<()> {
+        let dir = match direction {
+            "left" => TruncationDirection::Left,
+            "right" => TruncationDirection::Right,
+            _ => {
+                return Err(exceptions::PyValueError::new_err(format!(
+                    "Invalid truncation direction value : {}",
+                    direction
+                )))
+            }
+        };
+        setter!(self_, Truncate, params.direction, dir);
+        Ok(())
+    }
+
+    #[new]
+    #[pyo3(signature = (max_length=512, stride=0, direction="right"), text_signature = "(self, max_length=512, stride=0, direction='right')")]
+    fn new(max_length: usize, stride: usize, direction: &str) -> PyResult<(Self, PyPreTokenizer)> {
+        let dir = match direction {
+            "left" => TruncationDirection::Left,
+            "right" => TruncationDirection::Right,
+            _ => {
+                return Err(exceptions::PyValueError::new_err(format!(
+                    "Invalid truncation direction value : {}",
+                    direction
+                )))
+            }
+        };
+        let params = TruncationParams {
+            max_length,
+            stride,
+            direction: dir,
+            ..Default::default()
+        };
+        Ok((PyTruncate {}, Truncate::new(params).into()))
+    }
+}
+
 #[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
     inner: PyObject,
 }
@@ -926,6 +1006,7 @@ pub fn pre_tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PySequence>()?;
     m.add_class::<PyDigits>()?;
     m.add_class::<PyUnicodeScripts>()?;
+    m.add_class::<PyTruncate>()?;
     Ok(())
 }
diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py
index 3611930ae..60d63dcc9 100644
--- a/bindings/python/tests/bindings/test_pre_tokenizers.py
+++ b/bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -14,6 +14,7 @@
     Sequence,
     Split,
     UnicodeScripts,
+    Truncate,
     Whitespace,
     WhitespaceSplit,
 )
@@ -335,3 +336,13 @@ def pre_tokenize(self, pretok):
             ("Is", (15, 17)),
             ("Life", (17, 21)),
         ]
+
+
+class TestTruncate:
+    def test_right(self):
+        pretok = Truncate(5)
+        assert pretok.pre_tokenize_str("Hello World") == [("Hello", (0, 5))]
+
+    def test_left(self):
+        pretok = Truncate(5, direction="left")
+        assert pretok.pre_tokenize_str("Hello World") == [("World", (6, 11))]
diff --git a/tokenizers/src/pre_tokenizers/mod.rs b/tokenizers/src/pre_tokenizers/mod.rs
index 6195d170b..4853ff2de 100644
--- a/tokenizers/src/pre_tokenizers/mod.rs
+++ b/tokenizers/src/pre_tokenizers/mod.rs
@@ -8,6 +8,7 @@
 pub mod sequence;
 pub mod split;
 pub mod unicode_scripts;
 pub mod whitespace;
+pub mod truncate;
 
 use serde::{Deserialize, Deserializer, Serialize};
@@ -20,6 +21,7 @@
 use crate::pre_tokenizers::punctuation::Punctuation;
 use crate::pre_tokenizers::sequence::Sequence;
 use crate::pre_tokenizers::split::Split;
 use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
+use crate::pre_tokenizers::truncate::Truncate;
 use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
 use crate::{PreTokenizedString, PreTokenizer};
@@ -37,6 +39,7 @@ pub enum PreTokenizerWrapper {
     WhitespaceSplit(WhitespaceSplit),
     Digits(Digits),
     UnicodeScripts(UnicodeScripts),
+    Truncate(Truncate),
 }
 
 impl PreTokenizer for PreTokenizerWrapper {
@@ -53,6 +56,7 @@ impl PreTokenizer for PreTokenizerWrapper {
             Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
             Self::Digits(wspt) => wspt.pre_tokenize(normalized),
             Self::UnicodeScripts(us) => us.pre_tokenize(normalized),
+            Self::Truncate(t) => t.pre_tokenize(normalized),
         }
     }
 }
@@ -82,6 +86,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
             WhitespaceSplit,
             Digits,
             UnicodeScripts,
+            Truncate,
         }
 
         #[derive(Deserialize)]
@@ -105,6 +110,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
             WhitespaceSplit(WhitespaceSplit),
             Digits(Digits),
             UnicodeScripts(UnicodeScripts),
+            Truncate(Truncate),
         }
 
         let helper = PreTokenizerHelper::deserialize(deserializer)?;
@@ -152,6 +158,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
                     EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
                         serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                     ),
+                    EnumType::Truncate => PreTokenizerWrapper::Truncate(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
                 }
             }
@@ -187,6 +196,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
                 PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
                     PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
                 }
+                PreTokenizerUntagged::Truncate(truncate) => {
+                    PreTokenizerWrapper::Truncate(truncate)
+                }
             }
         }
     })
@@ -204,6 +216,7 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
 impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
 impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
 impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
+impl_enum_from!(Truncate, PreTokenizerWrapper, Truncate);
 
 #[cfg(test)]
 mod tests {
diff --git a/tokenizers/src/pre_tokenizers/truncate.rs b/tokenizers/src/pre_tokenizers/truncate.rs
new file mode 100644
index 000000000..6bac0371b
--- /dev/null
+++ b/tokenizers/src/pre_tokenizers/truncate.rs
@@ -0,0 +1,119 @@
+use serde::{Deserialize, Serialize};
+
+use crate::tokenizer::{
+    normalizer::Range, OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result,
+};
+use crate::utils::macro_rules_attribute;
+use crate::utils::truncation::{TruncationDirection, TruncationParams};
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[macro_rules_attribute(impl_serde_type!)]
+pub struct Truncate {
+    #[serde(flatten)]
+    pub params: TruncationParams,
+}
+
+impl Truncate {
+    pub fn new(params: TruncationParams) -> Self {
+        Self { params }
+    }
+}
+
+impl Default for Truncate {
+    fn default() -> Self {
+        Self {
+            params: TruncationParams::default(),
+        }
+    }
+}
+
+impl PreTokenizer for Truncate {
+    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
+        let max_len = self.params.max_length;
+        let total_len: usize = pretokenized
+            .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
+            .iter()
+            .map(|(s, _, _)| s.len())
+            .sum();
+        if total_len <= max_len {
+            return Ok(());
+        }
+
+        match self.params.direction {
+            TruncationDirection::Right => {
+                let mut remaining = max_len;
+                pretokenized.split(|_, mut s| {
+                    if remaining == 0 {
+                        Ok(Vec::new())
+                    } else {
+                        let len = s.len();
+                        if len <= remaining {
+                            remaining -= len;
+                            Ok(vec![s])
+                        } else {
+                            let slice = s
+                                .slice(Range::Normalized(0..remaining))
+                                .expect("NormalizedString bad slice");
+                            remaining = 0;
+                            Ok(vec![slice])
+                        }
+                    }
+                })
+            }
+            TruncationDirection::Left => {
+                let mut skip = total_len - max_len;
+                pretokenized.split(|_, mut s| {
+                    if skip >= s.len() {
+                        skip -= s.len();
+                        Ok(Vec::new())
+                    } else {
+                        if skip > 0 {
+                            let len = s.len();
+                            s = s
+                                .slice(Range::Normalized(skip..len))
+                                .expect("NormalizedString bad slice");
+                            skip = 0;
+                        }
+                        Ok(vec![s])
+                    }
+                })
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{OffsetReferential, OffsetType};
+
+    #[test]
+    fn truncate_right() {
+        let params = TruncationParams { max_length: 4, ..Default::default() };
+        let pretok = Truncate::new(params);
+        let mut pretokenized = PreTokenizedString::from("Hello World");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        let parts: Vec<_> = pretokenized
+            .get_splits(OffsetReferential::Original, OffsetType::Byte)
+            .into_iter()
+            .map(|(s, _o, _)| s)
+            .collect();
+        assert_eq!(parts.join(""), "Hell");
+    }
+
+    #[test]
+    fn truncate_left() {
+        let mut params = TruncationParams { max_length: 5, ..Default::default() };
+        params.direction = TruncationDirection::Left;
+        let pretok = Truncate::new(params);
+        let mut pretokenized = PreTokenizedString::from("Hello World");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        let parts: Vec<_> = pretokenized
+            .get_splits(OffsetReferential::Original, OffsetType::Byte)
+            .into_iter()
+            .map(|(s, _o, _)| s)
+            .collect();
+        assert_eq!(parts.join(""), "World");
+    }
+}
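
For quick reference, a minimal usage sketch of the new pre-tokenizer through the Python bindings (assuming the branch above is built and installed). The first two calls mirror the unit tests in this diff; the `Sequence` composition at the end is an illustrative addition, not part of the change:

    from tokenizers.pre_tokenizers import Sequence, Truncate, Whitespace

    # Keep at most the first 5 bytes of the input (default direction is "right").
    pre = Truncate(max_length=5)
    print(pre.pre_tokenize_str("Hello World"))  # [('Hello', (0, 5))]

    # Keep the last 5 bytes instead.
    pre = Truncate(max_length=5, direction="left")
    print(pre.pre_tokenize_str("Hello World"))  # [('World', (6, 11))]

    # Composed with another pre-tokenizer: truncation runs first, then whitespace splitting.
    pre = Sequence([Truncate(max_length=16), Whitespace()])
    print(pre.pre_tokenize_str("Hello World and more"))
    # [('Hello', (0, 5)), ('World', (6, 11)), ('and', (12, 15))]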