diff --git a/bindings/python/benches/test_deserialize.py b/bindings/python/benches/test_deserialize.py
new file mode 100644
index 000000000..d98d6d35c
--- /dev/null
+++ b/bindings/python/benches/test_deserialize.py
@@ -0,0 +1,62 @@
+import json
+import timeit
+from tokenizers import Tokenizer, AddedToken
+from tokenizers.models import WordLevel
+from tokenizers.normalizers import (
+    ByteLevel, Lowercase, NFC, NFD, NFKC, NFKD, Nmt, Strip, Replace, Prepend, BertNormalizer
+)
+import pytest
+
+
+def build_tokenizer_json(size, normalizer=None, special_tokens=True):
+    # Build a minimal vocab and WordLevel model
+    vocab = {"a": 0}
+    model = WordLevel(vocab=vocab, unk_token="[UNK]")
+    tokenizer = Tokenizer(model)
+    # Attach the normalizer if one was requested
+    if normalizer:
+        tokenizer.normalizer = normalizer
+    tokens = [AddedToken(f"tok{i}", special=special_tokens) for i in range(size)]
+    tokenizer.add_tokens(tokens)
+    # Return the serialized tokenizer JSON
+    return tokenizer.to_str()
+
+
+normalizer_factories = {
+    "none": None,
+    "byte_level": ByteLevel,
+    "lowercase": Lowercase,
+    "nfc": NFC,
+    "nfd": NFD,
+    "nfkc": NFKC,
+    "nfkd": NFKD,
+    "nmt": Nmt,
+    "strip": lambda: Strip(left=True, right=True),
+    "replace": lambda: Replace("a", "b"),
+    "prepend": lambda: Prepend("pre_"),
+    "bert": BertNormalizer
+}
+
+
+@pytest.mark.parametrize("special_tokens", [True, False])
+@pytest.mark.parametrize("norm_name,norm_factory", normalizer_factories.items())
+@pytest.mark.parametrize("size", [10_000, 100_000])
+def test_tokenizer_deserialization(benchmark, size, special_tokens, norm_name, norm_factory):
+    """Benchmark Tokenizer.from_str deserialization with different vocab sizes and normalizers."""
+    normalizer = norm_factory() if norm_factory else None
+    tok_json = build_tokenizer_json(size, normalizer, special_tokens)
+
+    def deserialize():
+        tok = Tokenizer.from_str(tok_json)
+        _ = tok
+
+    benchmark.group = f"deserialize_{size}_{'special' if special_tokens else 'non_special'}"
+    benchmark.name = f"norm_{norm_name}"
+    benchmark(deserialize)
+
+
+# Example usage:
+#   pytest benches/test_deserialize.py --benchmark-enable
+#   pytest benches/test_deserialize.py --benchmark-save=baseline
+#   pytest benches/test_deserialize.py --benchmark-compare=baseline
+#   pytest benches/test_deserialize.py --benchmark-compare=baseline --benchmark-save=baseline2
\ No newline at end of file
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index db56865d2..25dc7f2c8 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -41,6 +41,10 @@ name = "llama3"
 required-features = ["http"]
 harness = false
 
+[[bench]]
+name = "added_vocab_deserialize"
+harness = false
+
 [dependencies]
 rand = "0.8"
 onig = { version = "6.4", default-features = false, optional = true }
diff --git a/tokenizers/benches/added_vocab_deserialize.rs b/tokenizers/benches/added_vocab_deserialize.rs
new file mode 100644
index 000000000..5693919c4
--- /dev/null
+++ b/tokenizers/benches/added_vocab_deserialize.rs
@@ -0,0 +1,83 @@
+#[macro_use]
+extern crate criterion;
+use std::collections::HashMap;
+use std::str::FromStr;
+
+use criterion::{black_box, Criterion};
+use tokenizers::{models::wordlevel::WordLevel, normalizers::*, AddedToken, Normalizer, Tokenizer};
+
+fn serialized_tokenizer<N: Normalizer + Into<NormalizerWrapper>>(
+    size: usize,
+    normalizer: Option<N>,
+    special_tokens: bool,
+) -> String {
+    let mut vocab = HashMap::new();
+    vocab.insert("a".to_string(), 0);
+    let model = WordLevel::builder()
+        .vocab(vocab)
+        .unk_token("[UNK]".into())
+        .build()
+        .unwrap();
+
+    let mut tokenizer = Tokenizer::new(model);
+    let tokens: Vec<_> = (0..size)
+        .map(|i| AddedToken::from(format!("tok{i}"), special_tokens))
+        .collect();
+    tokenizer.add_tokens(&tokens);
+
+    if let Some(norm) = normalizer {
+        tokenizer.with_normalizer(Some(norm));
+    }
+
+    serde_json::to_string(&tokenizer).unwrap()
+}
+
+fn bench_deserialize(c: &mut Criterion) {
+    let normalizers: Vec<(&str, Option<fn() -> NormalizerWrapper>)> = vec![
+        ("none", None),
+        ("byte_level", Some(|| ByteLevel::default().into())),
+        ("lowercase", Some(|| Lowercase.into())),
+        ("nfc", Some(|| NFC.into())),
+        ("nfd", Some(|| NFD.into())),
+        ("nfkc", Some(|| NFKC.into())),
+        ("nfkd", Some(|| NFKD.into())),
+        ("nmt", Some(|| Nmt.into())),
+        ("strip", Some(|| Strip::new(true, true).into())),
+        ("replace", Some(|| Replace::new("a", "b").unwrap().into())),
+        ("prepend", Some(|| Prepend::new("pre_".to_string()).into())),
+        ("bert", Some(|| BertNormalizer::default().into())),
+    ];
+
+    for &size in &[10_000usize, 100_000, 400_000] {
+        for (norm_name, maybe_factory) in &normalizers {
+            let label = format!(
+                "special tokens deserialize_added_vocab_{}_norm_{}",
+                size, norm_name
+            );
+            let json = match maybe_factory {
+                Some(factory) => serialized_tokenizer(size, Some(factory()), true),
+                None => serialized_tokenizer::<NormalizerWrapper>(size, None, true),
+            };
+            c.bench_function(&label, |b| {
+                b.iter(|| {
+                    let tok: Tokenizer = black_box(Tokenizer::from_str(&json).unwrap());
+                    black_box(tok);
+                })
+            });
+
+            let label = format!(
+                "non special deserialize_added_vocab_{}_norm_{}",
+                size, norm_name
+            );
+            let json = match maybe_factory {
+                Some(factory) => serialized_tokenizer(size, Some(factory()), false),
+                None => serialized_tokenizer::<NormalizerWrapper>(size, None, false),
+            };
+            c.bench_function(&label, |b| {
+                b.iter(|| {
+                    let tok: Tokenizer = black_box(Tokenizer::from_str(&json).unwrap());
+                    black_box(tok);
+                })
+            });
+        }
+    }
+}
+
+criterion_group!(benches, bench_deserialize);
+criterion_main!(benches);
diff --git a/tokenizers/src/normalizers/byte_level.rs b/tokenizers/src/normalizers/byte_level.rs
index ae8fecfb6..ae89d74f2 100644
--- a/tokenizers/src/normalizers/byte_level.rs
+++ b/tokenizers/src/normalizers/byte_level.rs
@@ -4,6 +4,8 @@ use crate::utils::macro_rules_attribute;
 use std::collections::{HashMap, HashSet};
 use std::sync::LazyLock;
 
+
+
 #[derive(Clone, Debug)]
 #[macro_rules_attribute(impl_serde_type!)]
 pub struct ByteLevel;
@@ -46,7 +48,7 @@ impl Normalizer for ByteLevel {
             }
             normalized.transform(transformations, 0);
         }
-        Ok(())
+        Ok(())
     }
 }
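The last file below is the substantive change. Previously, add_tokens rescanned added_tokens_map_r for every incoming token, once via .values().any(...) to detect duplicates and once via max() over the assigned ids to pick the next id, so adding n tokens was quadratic; refresh_added_tokens also built a second trie even when no normalizer was configured. The rewrite tracks membership in a HashSet, keeps a running next_id counter, and reuses the non-normalized trie when there is no normalizer. The following standalone sketch illustrates only the id-assignment idea; assign_ids is a hypothetical helper over plain strings, not code from this patch:

use std::collections::{HashMap, HashSet};

// Illustrative sketch only (not code from this patch): assign consecutive ids
// to new tokens while skipping duplicates, without rescanning the map per token.
fn assign_ids(tokens: &[String], vocab_size: u32) -> HashMap<String, u32> {
    let mut map = HashMap::new();
    let mut existing: HashSet<&str> = HashSet::new();
    let mut next_id = vocab_size; // first free id after the model's vocabulary

    for tok in tokens {
        if tok.is_empty() || !existing.insert(tok.as_str()) {
            continue; // O(1) duplicate check instead of a linear scan
        }
        map.insert(tok.clone(), next_id);
        next_id += 1; // O(1) instead of max() over all previously assigned ids
    }
    map
}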
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index a93d53289..8bd266284 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -272,35 +272,38 @@ impl AddedVocabulary {
             }
         }
 
-        // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
+
+        let mut existing: HashSet<AddedToken> = self.added_tokens_map_r.values().cloned().collect();
+        let mut next_id = self.added_tokens_map_r.keys().copied().max().map_or(
+            model.get_vocab_size() as u32,
+            |max| {
+                if max >= model.get_vocab_size() as u32 || model.get_vocab_size() == 0 {
+                    max + 1
+                } else {
+                    model.get_vocab_size() as u32
+                }
+            },
+        );
+
         for token in tokens {
-            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
-            {
+            if token.content.is_empty() || existing.contains(token) {
                 ignored += 1;
                 continue;
            }
-            // If a token is already part of the vocabulary, we mark it as added
+
             let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
                 new_id
             } else {
-                self.added_tokens_map.values().cloned().max().map_or(
-                    model.get_vocab_size() as u32,
-                    |max| {
-                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
-                            max + 1
-                        } else {
-                            model.get_vocab_size() as u32
-                        }
-                    },
-                )
+                let id = next_id;
+                next_id += 1;
+                id
             };
-            // Make sure we modify the previous entry
+
             self.added_tokens_map
                 .entry(token.content.clone())
                 .and_modify(|old_id| *old_id = new_id)
-                .or_insert_with(|| new_id);
-            // Update the current revert operation
+                .or_insert(new_id);
             self.added_tokens_map_r
                 .entry(new_id)
                 .and_modify(|t| *t = token.clone())
@@ -311,6 +314,7 @@
             if !self.special_tokens_set.contains(&token.content) {
                 self.added_tokens.push(token.clone());
             }
+            existing.insert(token.clone());
         }
 
         self.refresh_added_tokens(model, normalizer);
@@ -338,29 +342,33 @@
             })
             .partition(|(token, _)| token.normalized);
 
+        // Build the non-normalized trie
         let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
         let trie = AhoCorasickBuilder::new()
             .match_kind(MatchKind::LeftmostLongest)
             .build(tokens.iter().map(|token| &token.content))
-            .expect("Failed to build tried when refreshing tokens");
-        self.split_trie = (trie, ids);
-
-        let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
-        let patterns: Vec<_> = ntokens
-            .iter()
-            .map(|token| {
-                let mut content = NormalizedString::from(token.content.as_ref());
-                if let Some(n) = normalizer {
+            .expect("Failed to build trie when refreshing tokens");
+        self.split_trie = (trie.clone(), ids.clone());
+
+        // Build the normalized trie
+        if let Some(n) = normalizer {
+            let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
+            let patterns: Vec<_> = ntokens
+                .iter()
+                .map(|token| {
+                    let mut content = NormalizedString::from(token.content.as_ref());
                     n.normalize(&mut content).unwrap();
-                }
-                content
-            })
-            .collect();
-        let normalized_trie = AhoCorasickBuilder::new()
-            .match_kind(MatchKind::LeftmostLongest)
-            .build(patterns.iter().map(|content| content.get()))
-            .expect("Failed to build tried when refreshing tokens (normalized)");
-        self.split_normalized_trie = (normalized_trie, nids);
+                    content
+                })
+                .collect();
+            let normalized_trie = AhoCorasickBuilder::new()
+                .match_kind(MatchKind::LeftmostLongest)
+                .build(patterns.iter().map(|content| content.get()))
+                .expect("Failed to build trie when refreshing tokens (normalized)");
+            self.split_normalized_trie = (normalized_trie, nids);
+        } else {
+            self.split_normalized_trie = (trie, ids); // without a normalizer, both tries are identical
+        }
     }
 
     /// Find any AddedToken in the given sentence, using the provided MatchingSet.
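As a complement to the timing benchmarks, a quick behavioral check can confirm that the new id-assignment path yields the same ids across a serialize/deserialize round trip. The sketch below is not part of this patch; it only uses APIs already exercised by the new bench and assumes serde_json is available, as it is for the bench above.

use std::collections::HashMap;
use std::str::FromStr;
use tokenizers::{models::wordlevel::WordLevel, AddedToken, Tokenizer};

fn main() {
    // Same tiny WordLevel model as the benchmarks use.
    let mut vocab = HashMap::new();
    vocab.insert("a".to_string(), 0);
    let model = WordLevel::builder()
        .vocab(vocab)
        .unk_token("[UNK]".into())
        .build()
        .unwrap();

    // Register a batch of special tokens, then round-trip through JSON.
    let mut tokenizer = Tokenizer::new(model);
    let tokens: Vec<_> = (0..1_000)
        .map(|i| AddedToken::from(format!("tok{i}"), true))
        .collect();
    tokenizer.add_tokens(&tokens);

    let json = serde_json::to_string(&tokenizer).unwrap();
    let reloaded = Tokenizer::from_str(&json).unwrap();

    // Ids assigned by add_tokens should survive deserialization unchanged.
    for i in 0..1_000 {
        let tok = format!("tok{i}");
        assert_eq!(tokenizer.token_to_id(&tok), reloaded.token_to_id(&tok));
    }
}

Assuming Criterion's standard CLI, the Rust benchmark runs with cargo bench --bench added_vocab_deserialize, and its --save-baseline / --baseline flags (passed after --) can be used to compare runs from before and after this change, mirroring the pytest-benchmark commands listed at the end of the Python file.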