
Add Truncate pre-tokenizer #1783


Draft
wants to merge 1 commit into main
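A minimal usage sketch of the API this PR proposes, based on the `.pyi` stub and the tests further down in this diff (the constructor defaults `max_length=512, stride=0, direction="right"` come from the stub; the expected outputs are the ones asserted in `test_pre_tokenizers.py`):

```python
from tokenizers.pre_tokenizers import Truncate

# Keep only the first 5 bytes of the input before any further pre-tokenization.
pretok = Truncate(max_length=5)
print(pretok.pre_tokenize_str("Hello World"))
# [("Hello", (0, 5))]

# Keep the last 5 bytes instead; offsets still refer to the original string.
pretok = Truncate(max_length=5, direction="left")
print(pretok.pre_tokenize_str("Hello World"))
# [("World", (6, 11))]
```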
bindings/python/py_src/tokenizers/pre_tokenizers/__init__.py
@@ -11,5 +11,6 @@
Sequence = pre_tokenizers.Sequence
Split = pre_tokenizers.Split
UnicodeScripts = pre_tokenizers.UnicodeScripts
Truncate = pre_tokenizers.Truncate
Whitespace = pre_tokenizers.Whitespace
WhitespaceSplit = pre_tokenizers.WhitespaceSplit
12 changes: 12 additions & 0 deletions bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi
@@ -521,6 +521,18 @@ class UnicodeScripts(PreTokenizer):
"""
pass

class Truncate(PreTokenizer):
"""Truncate pre-tokenizer"""

def __init__(self, max_length: int = 512, stride: int = 0, direction: str = "right"):
pass

def pre_tokenize(self, pretok):
pass

def pre_tokenize_str(self, sequence):
pass

class Whitespace(PreTokenizer):
"""
This pre-tokenizer simply splits using the following regex: `\w+|[^\w\s]+`
81 changes: 81 additions & 0 deletions bindings/python/src/pre_tokenizers.rs
@@ -17,7 +17,9 @@ use tk::pre_tokenizers::punctuation::Punctuation;
use tk::pre_tokenizers::split::Split;
use tk::pre_tokenizers::unicode_scripts::UnicodeScripts;
use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use tk::pre_tokenizers::truncate::Truncate;
use tk::pre_tokenizers::PreTokenizerWrapper;
use tk::utils::truncation::{TruncationDirection, TruncationParams};
use tk::tokenizer::Offsets;
use tk::{PreTokenizedString, PreTokenizer};
use tokenizers as tk;
@@ -118,6 +120,12 @@ impl PyPreTokenizer {
.into_any()
.into()
}
PreTokenizerWrapper::Truncate(_) => {
Py::new(py, (PyTruncate {}, base))?
.into_pyobject(py)?
.into_any()
.into()
}
},
}
}
@@ -750,6 +758,78 @@ impl PyUnicodeScripts {
}
}

/// Truncate pre-tokenizer
///
/// This pre-tokenizer truncates text based on the provided parameters before tokenization.
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name = "Truncate")]
pub struct PyTruncate {}

#[pymethods]
impl PyTruncate {
#[getter]
fn get_max_length(self_: PyRef<Self>) -> usize {
getter!(self_, Truncate, params.max_length)
}

#[setter]
fn set_max_length(self_: PyRef<Self>, value: usize) {
setter!(self_, Truncate, params.max_length, value);
}

#[getter]
fn get_stride(self_: PyRef<Self>) -> usize {
getter!(self_, Truncate, params.stride)
}

#[setter]
fn set_stride(self_: PyRef<Self>, value: usize) {
setter!(self_, Truncate, params.stride, value);
}

#[getter]
fn get_direction(self_: PyRef<Self>) -> String {
getter!(self_, Truncate, params.direction.as_ref()).to_string()
}

#[setter]
fn set_direction(self_: PyRef<Self>, direction: &str) -> PyResult<()> {
let dir = match direction {
"left" => TruncationDirection::Left,
"right" => TruncationDirection::Right,
_ => {
return Err(exceptions::PyValueError::new_err(format!(
"Invalid truncation direction value : {}",
direction
)))
}
};
setter!(self_, Truncate, @params.direction, dir);
Ok(())
}

#[new]
#[pyo3(signature = (max_length=512, stride=0, direction="right"), text_signature = "(self, max_length=512, stride=0, direction='right')")]
fn new(max_length: usize, stride: usize, direction: &str) -> PyResult<(Self, PyPreTokenizer)> {
let dir = match direction {
"left" => TruncationDirection::Left,
"right" => TruncationDirection::Right,
_ => {
return Err(exceptions::PyValueError::new_err(format!(
"Invalid truncation direction value : {}",
direction
)))
}
};
let params = TruncationParams {
max_length,
stride,
direction: dir,
..Default::default()
};
Ok((PyTruncate {}, Truncate::new(params).into()))
}
}

#[derive(Clone)]
pub(crate) struct CustomPreTokenizer {
inner: PyObject,
@@ -926,6 +1006,7 @@ pub fn pre_tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PySequence>()?;
m.add_class::<PyDigits>()?;
m.add_class::<PyUnicodeScripts>()?;
m.add_class::<PyTruncate>()?;
Ok(())
}

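The getters and setters above mirror the constructor arguments; a small sketch of how they would behave from Python, assuming the binding lands as written (the `ValueError` is the `PyValueError` raised by `set_direction` for anything other than "left"/"right"):

```python
from tokenizers.pre_tokenizers import Truncate

pretok = Truncate()          # max_length=512, stride=0, direction="right"
pretok.max_length = 8        # exposed via the #[getter]/#[setter] pair
pretok.direction = "left"    # only "left" / "right" are accepted

try:
    pretok.direction = "middle"
except ValueError as e:
    print(e)                 # "Invalid truncation direction value : middle"
```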
11 changes: 11 additions & 0 deletions bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -14,6 +14,7 @@
Sequence,
Split,
UnicodeScripts,
Truncate,
Whitespace,
WhitespaceSplit,
)
@@ -335,3 +336,13 @@ def pre_tokenize(self, pretok):
("Is", (15, 17)),
("Life", (17, 21)),
]


class TestTruncate:
def test_right(self):
pretok = Truncate(5)
assert pretok.pre_tokenize_str("Hello World") == [("Hello", (0, 5))]

def test_left(self):
pretok = Truncate(5, direction="left")
assert pretok.pre_tokenize_str("Hello World") == [("World", (6, 11))]
13 changes: 13 additions & 0 deletions tokenizers/src/pre_tokenizers/mod.rs
@@ -8,6 +8,7 @@ pub mod sequence;
pub mod split;
pub mod unicode_scripts;
pub mod whitespace;
pub mod truncate;

use serde::{Deserialize, Deserializer, Serialize};

@@ -20,6 +21,7 @@ use crate::pre_tokenizers::punctuation::Punctuation;
use crate::pre_tokenizers::sequence::Sequence;
use crate::pre_tokenizers::split::Split;
use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
use crate::pre_tokenizers::truncate::Truncate;
use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use crate::{PreTokenizedString, PreTokenizer};

@@ -37,6 +39,7 @@ pub enum PreTokenizerWrapper {
WhitespaceSplit(WhitespaceSplit),
Digits(Digits),
UnicodeScripts(UnicodeScripts),
Truncate(Truncate),
}

impl PreTokenizer for PreTokenizerWrapper {
@@ -53,6 +56,7 @@ impl PreTokenizer for PreTokenizerWrapper {
Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
Self::Digits(wspt) => wspt.pre_tokenize(normalized),
Self::UnicodeScripts(us) => us.pre_tokenize(normalized),
Self::Truncate(t) => t.pre_tokenize(normalized),
}
}
}
@@ -82,6 +86,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
WhitespaceSplit,
Digits,
UnicodeScripts,
Truncate,
}

#[derive(Deserialize)]
Expand All @@ -105,6 +110,7 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
WhitespaceSplit(WhitespaceSplit),
Digits(Digits),
UnicodeScripts(UnicodeScripts),
Truncate(Truncate),
}

let helper = PreTokenizerHelper::deserialize(deserializer)?;
@@ -152,6 +158,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Truncate => PreTokenizerWrapper::Truncate(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
}
}

@@ -187,6 +196,9 @@ impl<'de> Deserialize<'de> for PreTokenizerWrapper {
PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
}
PreTokenizerUntagged::Truncate(truncate) => {
PreTokenizerWrapper::Truncate(truncate)
}
}
}
})
@@ -204,6 +216,7 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
impl_enum_from!(Truncate, PreTokenizerWrapper, Truncate);

#[cfg(test)]
mod tests {
119 changes: 119 additions & 0 deletions tokenizers/src/pre_tokenizers/truncate.rs
@@ -0,0 +1,119 @@
use serde::{Deserialize, Serialize};

use crate::tokenizer::{
normalizer::Range, OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result,
};
use crate::utils::macro_rules_attribute;
use crate::utils::truncation::{TruncationDirection, TruncationParams};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Truncate {
#[serde(flatten)]
pub params: TruncationParams,

Check failure on line 13 (GitHub Actions / Check it builds for Windows 32-bit (3.10)): binary operation `==` cannot be applied to type `TruncationParams`
}

impl Truncate {
pub fn new(params: TruncationParams) -> Self {
Self { params }
}
}

impl Default for Truncate {
fn default() -> Self {
Self {
params: TruncationParams::default(),
}
}
}

impl PreTokenizer for Truncate {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let max_len = self.params.max_length;
let total_len: usize = pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.iter()
.map(|(s, _, _)| s.len())
.sum();
if total_len <= max_len {
return Ok(());
}

match self.params.direction {
TruncationDirection::Right => {
let mut remaining = max_len;
pretokenized.split(|_, mut s| {

Check warning on line 45 (GitHub Actions / Check it builds for Windows 32-bit (3.10)): variable does not need to be mutable
if remaining == 0 {
Ok(Vec::new())
} else {
let len = s.len();
if len <= remaining {
remaining -= len;
Ok(vec![s])
} else {
let slice = s
.slice(Range::Normalized(0..remaining))
.expect("NormalizedString bad slice");
remaining = 0;
Ok(vec![slice])
}
}
})
}
TruncationDirection::Left => {
let mut skip = total_len - max_len;
pretokenized.split(|_, mut s| {
if skip >= s.len() {
skip -= s.len();
Ok(Vec::new())
} else {
if skip > 0 {
let len = s.len();
s = s
.slice(Range::Normalized(skip..len))
.expect("NormalizedString bad slice");
skip = 0;
}
Ok(vec![s])
}
})
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{OffsetReferential, OffsetType};

#[test]
fn truncate_right() {
let params = TruncationParams { max_length: 4, ..Default::default() };
let pretok = Truncate::new(params);
let mut pretokenized = PreTokenizedString::from("Hello World");
pretok.pre_tokenize(&mut pretokenized).unwrap();
let parts: Vec<_> = pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, _o, _)| s)
.collect();
assert_eq!(parts.join(""), "Hell");
}

#[test]
fn truncate_left() {
let mut params = TruncationParams { max_length: 5, ..Default::default() };
params.direction = TruncationDirection::Left;
let pretok = Truncate::new(params);
let mut pretokenized = PreTokenizedString::from("Hello World");
pretok.pre_tokenize(&mut pretokenized).unwrap();
let parts: Vec<_> = pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, _o, _)| s)
.collect();
assert_eq!(parts.join(""), "World");
}
}
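Since `Truncate` is added to `PreTokenizerWrapper`, it should compose with the existing pre-tokenizers. A hedged sketch of chaining it ahead of `Whitespace` via `Sequence` from Python, assuming the crate-side issue flagged by the CI annotation above is resolved (the printed result is an expectation derived from the truncation and whitespace-splitting rules in this diff, not an output taken from the tests):

```python
from tokenizers.pre_tokenizers import Sequence, Truncate, Whitespace

# Truncate the raw text to 8 bytes first, then split what remains on whitespace.
pre_tok = Sequence([Truncate(max_length=8), Whitespace()])
print(pre_tok.pre_tokenize_str("Hello World"))
# Expected: [("Hello", (0, 5)), ("Wo", (6, 8))]
```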
