diff --git a/bindings/node/Cargo.toml b/bindings/node/Cargo.toml index cf1e51e99..6e00f9d7c 100644 --- a/bindings/node/Cargo.toml +++ b/bindings/node/Cargo.toml @@ -14,6 +14,7 @@ napi = "2" napi-derive = "2" serde = { version = "1.0.163", features = ["derive"] } tokenizers = { path = "../../tokenizers/" } +ahash = { version = "0.8.11", features = ["serde"] } [build-dependencies] napi-build = "2" diff --git a/bindings/node/src/models.rs b/bindings/node/src/models.rs index a4138b91f..9ee7f60f7 100644 --- a/bindings/node/src/models.rs +++ b/bindings/node/src/models.rs @@ -1,6 +1,7 @@ use crate::arc_rwlock_serde; use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask}; use crate::trainers::Trainer; +use ahash::AHashMap; use napi::bindgen_prelude::*; use napi_derive::napi; use serde::{Deserialize, Serialize}; @@ -8,7 +9,7 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; use tokenizers as tk; -use tokenizers::models::bpe::{BpeBuilder, Merges, Vocab}; +use tokenizers::models::bpe::{BpeBuilder, Merges}; use tokenizers::models::wordlevel::WordLevelBuilder; use tokenizers::models::wordpiece::WordPieceBuilder; @@ -44,8 +45,13 @@ impl Bpe { } #[napi(factory, ts_return_type = "Model")] - pub fn init(vocab: Vocab, merges: Merges, options: Option) -> Result { + pub fn init( + vocab: HashMap, + merges: Merges, + options: Option, + ) -> Result { let options = options.unwrap_or_default(); + let vocab: AHashMap<_, _> = vocab.into_iter().collect(); let mut builder = tk::models::bpe::BPE::builder().vocab_and_merges(vocab, merges); builder = options.apply_to_bpe_builder(builder); let model = builder @@ -206,10 +212,11 @@ pub struct WordPiece {} #[napi] impl WordPiece { #[napi(factory, ts_return_type = "Model")] - pub fn init(vocab: Vocab, options: Option) -> Result { + pub fn init(vocab: HashMap, options: Option) -> Result { let options = options.unwrap_or_default(); - let mut builder = tk::models::wordpiece::WordPiece::builder().vocab(vocab); + let mut builder = tk::models::wordpiece::WordPiece::builder() + .vocab(vocab.into_iter().collect::>()); builder = options.apply_to_wordpiece_builder(builder); let model = builder .build() @@ -263,9 +270,10 @@ pub struct WordLevel {} #[napi] impl WordLevel { #[napi(factory, ts_return_type = "Model")] - pub fn init(vocab: Vocab, options: Option) -> Result { + pub fn init(vocab: HashMap, options: Option) -> Result { let options = options.unwrap_or_default(); - let mut builder = tk::models::wordlevel::WordLevel::builder().vocab(vocab); + let mut builder = + tk::models::wordlevel::WordLevel::builder().vocab(vocab.into_iter().collect()); builder = options.apply_to_wordlevel_builder(builder); let model = builder .build() diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index c0f05ac6b..76f09604a 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,6 +18,7 @@ pyo3 = { version = "0.25", features = ["abi3", "abi3-py39", "py-clone"] } numpy = "0.25" ndarray = "0.16" itertools = "0.14" +ahash = { version = "0.8.11", features = ["serde"] } [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/benches/test_backtrack.py b/bindings/python/benches/test_backtrack.py new file mode 100644 index 000000000..0988d387c --- /dev/null +++ b/bindings/python/benches/test_backtrack.py @@ -0,0 +1,88 @@ +import os +import argparse +import datetime +from datasets import load_dataset +from tokenizers import Tokenizer +from typing import Tuple + 
+MODEL_ID = "meta-llama/Meta-Llama-3.1-8B" +DATASET = "facebook/xnli" +DATASET_CONFIG = "all_languages" +DEFAULT_THREADS = [2**i for i in range(8) if 2**i <= os.cpu_count()] + + +def format_byte_size(num_bytes: int) -> Tuple[str, str]: + """Convert bytes to a human-readable format (KB, MB, GB).""" + num_bytes_f = float(num_bytes) + for unit in ["B", "KB", "MB", "GB", "TB"]: + if num_bytes_f < 1024: + return f"{num_bytes_f:.2f} {unit}", unit + num_bytes_f /= 1024 + return f"{num_bytes_f:.2f} PB", "PB" + + +def test(model: str, dataset: str, dataset_config: str): + dataset_xnli = load_dataset(dataset, dataset_config) + tokenizer = Tokenizer.from_pretrained(model) + tokenizer2 = Tokenizer.from_pretrained(model) + tokenizer2.enable_backtrack() + + for easy in ["1880", " cream"]: + encoded = tokenizer.encode(easy) + encoded2 = tokenizer2.encode(easy) + if encoded.ids != encoded2.ids: + import ipdb + + ipdb.set_trace() + assert encoded.ids == encoded2.ids + + sentences = [] + en_sentences = [] + for _i, item in enumerate(dataset_xnli["train"]): + # sentence = item["premise"]["en"] + # sentences.append(sentence) + for lang, sentence in item["premise"].items(): + if lang == "en": + en_sentences.append(sentence) + sentences.append(sentence) + sentences = en_sentences + sentences + + start = datetime.datetime.now() + encoded = tokenizer.encode_batch_fast(sentences) + print(f"Took {datetime.datetime.now() - start}") + + start = datetime.datetime.now() + encoded2 = tokenizer2.encode_batch_fast(sentences) + print(f"Took {datetime.datetime.now() - start}") + + assert len(encoded) == len(encoded2) + assert len(encoded) == len(sentences) + total = 0 + correct = 0 + for enc, enc2, sentence in zip(encoded, encoded2, sentences): + # if enc.ids != enc2.ids: + # print(enc.ids) + # print(enc2.ids) + if enc.ids == enc2.ids: + correct += 1 + total += 1 + assert enc.ids == enc2.ids, f"{enc.ids} != {enc2.ids} (Source: {sentence}" + print(f"{correct} / {total} ({correct / total * 100:.2f}%%)") + # print("All good !") + + +def main(): + parser = argparse.ArgumentParser( + prog="bench_tokenizer", + description="Getting a feel for speed when tokenizing", + ) + parser.add_argument("-m", "--model", default=MODEL_ID, type=str) + parser.add_argument("-d", "--dataset", default=DATASET, type=str) + parser.add_argument("-ds", "--dataset-config", default=DATASET_CONFIG, type=str) + args = parser.parse_args() + test(args.model, args.dataset, args.dataset_config) + + +# Call the function to run the benchmark +if __name__ == "__main__": + main() diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 2f4dba825..81d5f4eb6 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -4,11 +4,12 @@ use std::sync::{Arc, RwLock}; use crate::token::PyToken; use crate::trainers::PyTrainer; +use ahash::AHashMap; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; use serde::{Deserialize, Serialize}; -use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE}; +use tk::models::bpe::{BpeBuilder, Merges, BPE}; use tk::models::unigram::Unigram; use tk::models::wordlevel::WordLevel; use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; @@ -347,9 +348,10 @@ macro_rules! 
setter { #[derive(FromPyObject)] enum PyVocab { - Vocab(Vocab), + Vocab(HashMap), Filename(String), } + #[derive(FromPyObject)] enum PyMerges { Merges(Merges), @@ -454,6 +456,7 @@ impl PyBPE { if let (Some(vocab), Some(merges)) = (vocab, merges) { match (vocab, merges) { (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => { + let vocab: AHashMap<_, _> = vocab.into_iter().collect(); builder = builder.vocab_and_merges(vocab, merges); } (PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => { @@ -494,13 +497,15 @@ impl PyBPE { /// The vocabulary and merges loaded into memory #[staticmethod] #[pyo3(text_signature = "(self, vocab, merges)")] - fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> { - BPE::read_file(vocab, merges).map_err(|e| { + fn read_file(vocab: &str, merges: &str) -> PyResult<(HashMap, Merges)> { + let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| { exceptions::PyException::new_err(format!( "Error while reading vocab & merges files: {}", e )) - }) + })?; + let vocab = vocab.into_iter().collect(); + Ok((vocab, merges)) } /// Instantiate a BPE model from the given files. @@ -536,6 +541,7 @@ impl PyBPE { let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| { exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e)) })?; + let vocab = vocab.into_iter().collect(); Py::new( py, PyBPE::new( @@ -668,6 +674,7 @@ impl PyWordPiece { if let Some(vocab) = vocab { match vocab { PyVocab::Vocab(vocab) => { + let vocab: AHashMap<_, _> = vocab.into_iter().collect(); builder = builder.vocab(vocab); } PyVocab::Filename(vocab_filename) => { @@ -699,10 +706,11 @@ impl PyWordPiece { /// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` #[staticmethod] #[pyo3(text_signature = "(vocab)")] - fn read_file(vocab: &str) -> PyResult { - WordPiece::read_file(vocab).map_err(|e| { + fn read_file(vocab: &str) -> PyResult> { + let vocab = WordPiece::read_file(vocab).map_err(|e| { exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e)) - }) + })?; + Ok(vocab.into_iter().collect()) } /// Instantiate a WordPiece model from the given file @@ -734,6 +742,7 @@ impl PyWordPiece { let vocab = WordPiece::read_file(vocab).map_err(|e| { exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e)) })?; + let vocab = vocab.into_iter().collect(); Py::new( py, PyWordPiece::new(py, Some(PyVocab::Vocab(vocab)), kwargs)?, @@ -778,6 +787,7 @@ impl PyWordLevel { if let Some(vocab) = vocab { match vocab { PyVocab::Vocab(vocab) => { + let vocab = vocab.into_iter().collect(); builder = builder.vocab(vocab); } PyVocab::Filename(vocab_filename) => { @@ -818,10 +828,12 @@ impl PyWordLevel { /// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict` #[staticmethod] #[pyo3(text_signature = "(vocab)")] - fn read_file(vocab: &str) -> PyResult { - WordLevel::read_file(vocab).map_err(|e| { + fn read_file(vocab: &str) -> PyResult> { + let vocab = WordLevel::read_file(vocab).map_err(|e| { exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e)) - }) + })?; + let vocab: HashMap<_, _> = vocab.into_iter().collect(); + Ok(vocab) } /// Instantiate a WordLevel model from the given file @@ -853,6 +865,7 @@ impl PyWordLevel { let vocab = WordLevel::read_file(vocab).map_err(|e| { exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e)) })?; + let vocab = vocab.into_iter().collect(); Py::new( py, PyWordLevel::new(py, Some(PyVocab::Vocab(vocab)), 
unk_token)?, diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 73a0dbbe8..124319838 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,6 +1,8 @@ use serde::Serialize; use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; +use tk::pre_tokenizers::byte_level::ByteLevel; +use tk::ModelWrapper; use numpy::{npyffi, PyArray1, PyArrayMethods}; use pyo3::class::basic::CompareOp; @@ -1118,6 +1120,19 @@ impl PyTokenizer { .into() }) } + /// + #[pyo3(signature = ())] + #[pyo3(text_signature = "(self)")] + fn enable_backtrack(&mut self) -> PyResult<()> { + // self.tokenizer.with_pre_tokenizer(None::); + let model = self.tokenizer.get_model(); + let mut model = model.model.write().unwrap(); + let ModelWrapper::BPE(ref mut model) = *model else { + todo!(); + }; + model.enable_backtrack(); + Ok(()) + } /// Decode the given list of ids back to a string /// diff --git a/bindings/python/test.py b/bindings/python/test.py new file mode 100644 index 000000000..931a2a353 --- /dev/null +++ b/bindings/python/test.py @@ -0,0 +1,313 @@ +import torch +from transformers import AutoModel +from transformers import AutoTokenizer +from faker import Faker +from huggingface_hub import hf_hub_download +import json + +# Create a Faker instance with Japanese locale +fake = Faker("ja_JP") + + +# Generate random Japanese text +def generate_random_japanese_text(): + return fake.text() + + +def move_to_cuda(sample): + if len(sample) == 0: + return {} + + def _move_to_cuda(maybe_tensor): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.cuda(non_blocking=True) + elif isinstance(maybe_tensor, dict): + return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()} + elif isinstance(maybe_tensor, list): + return [_move_to_cuda(x) for x in maybe_tensor] + elif isinstance(maybe_tensor, tuple): + return tuple([_move_to_cuda(x) for x in maybe_tensor]) + # elif isinstance(maybe_tensor, Mapping): + # return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()}) + else: + return maybe_tensor + + return _move_to_cuda(sample) + + +def create_batch_dict(tokenizer, input_texts, max_length: int = 512): + return tokenizer( + input_texts, + max_length=max_length, + padding=True, + pad_to_multiple_of=8, + return_token_type_ids=False, + truncation=True, + return_tensors="pt", + ) + + +def pool(last_hidden_states, attention_mask, pool_type: str): + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + + if pool_type == "avg": + emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pool_type == "weightedavg": # position-weighted mean pooling from SGPT (https://arxiv.org/abs/2202.08904) + attention_mask *= attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0] + s = torch.sum(last_hidden * attention_mask.unsqueeze(-1).float(), dim=1) + d = attention_mask.sum(dim=1, keepdim=True).float() + emb = s / d + elif pool_type == "cls": + emb = last_hidden[:, 0] + elif pool_type == "last": + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + emb = last_hidden[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden.shape[0] + emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths] + else: + raise ValueError(f"pool_type {pool_type} not supported") + + return emb + + +class KVEmbedding: + def __init__(self, device): + self.device = device + + # Load 
tokenizer and model from pretrained multilingual-e5-small + self.tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small") + self.model = AutoModel.from_pretrained("intfloat/multilingual-e5-small").to(self.device) + + self.model.eval() # Set model to evaluation mode + + def average_pool(self, last_hidden_states, attention_mask): + # Apply mask to hidden states, set masked positions to 0 + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + # Average the hidden states along the sequence dimension + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + + def embedding(self, l_transcription, batch_size=32): + # Tokenize input transcriptions + batch_dict = self.tokenizer( + l_transcription, + max_length=512, + padding=True, + truncation=True, + return_tensors="pt", + ).to(self.device) + + return batch_dict + + def _do_encode(self, input_texts): + encoded_embeds = [] + batch_size = 64 + for start_idx in range(0, len(input_texts), batch_size): + batch_input_texts = input_texts[start_idx : start_idx + batch_size] + + batch_dict = create_batch_dict(self.tokenizer, batch_input_texts) + # batch_dict = move_to_cuda(batch_dict) + return encoded_embeds + + +import random +from faker import Faker + +# # Lists of Japanese characters +hiragana = [ + "あ", + "い", + "う", + "え", + "お", + "か", + "き", + "く", + "け", + "こ", + "さ", + "し", + "す", + "せ", + "そ", + "た", + "ち", + "つ", + "て", + "と", + "な", + "に", + "ぬ", + "ね", + "の", + "は", + "ひ", + "ふ", + "へ", + "ほ", + "ま", + "み", + "む", + "め", + "も", + "や", + "ゆ", + "よ", + "ら", + "り", + "る", + "れ", + "ろ", + "わ", + "を", + "ん", +] +katakana = [ + "ア", + "イ", + "ウ", + "エ", + "オ", + "カ", + "キ", + "ク", + "ケ", + "コ", + "サ", + "シ", + "ス", + "セ", + "ソ", + "タ", + "チ", + "ツ", + "テ", + "ト", + "ナ", + "ニ", + "ヌ", + "ネ", + "ノ", + "ハ", + "ヒ", + "フ", + "ヘ", + "ホ", + "マ", + "ミ", + "ム", + "メ", + "モ", + "ヤ", + "ユ", + "ヨ", + "ラ", + "リ", + "ル", + "レ", + "ロ", + "ワ", + "ヲ", + "ン", +] +kanji = [ + "日", + "本", + "語", + "学", + "校", + "生", + "時", + "間", + "人", + "大", + "小", + "中", + "山", + "川", + "口", + "目", + "耳", + "手", + "足", + "力", + "男", + "女", + "子", + "父", + "母", +] + +# Combine all character sets +all_characters = hiragana + katakana + kanji + + +# Generate random Japanese text +def generate_random_japanese(length): + return "".join(random.choices(all_characters, k=length)) + + +def remove_invalid_characters(valid_chars, text): + """ + Removes all invalid characters from the given text, keeping only the characters present in char_dicts. + + Args: + char_dicts (dict): Dictionary of valid characters. + text (str): Input text string. + + Returns: + str: Text string with only valid characters. 
+ """ + # Convert dict keys to a set for faster lookup + filtered_text = "".join(c for c in text if c in valid_chars) + return filtered_text + + +if __name__ == "__main__": + from tqdm import tqdm + import psutil + + print("Start app ...") + filename = hf_hub_download("intfloat/multilingual-e5-small", "tokenizer.json") + with open(filename, "r") as file: + character_info = json.load(file) + character_dict = {} + print("Vocab is loading ...") + with tqdm(total=100, desc="cpu%", position=1) as cpubar, tqdm(total=100, desc="ram%", position=0) as rambar: + for data in character_info["model"]["vocab"]: + character_dict[data[0]] = data[1] + valid_chars = set(character_dict.keys()) + print("Start loading model") + kv_embedding = KVEmbedding("cpu") + print("Loading model: Done!!!") + for i in range(7500): + print(f"============{i}==============") + length = random.randint(600, 1000) + # print(length) + input_texts = [] + for s in range(length): + text_length = random.randint(1, 10000) + + random_text = generate_random_japanese(text_length) + + # before = len(random_text) + random_text = remove_invalid_characters(valid_chars, random_text) + # after = len(random_text) + # if after != before: + # print(before, after) + random_text = random_text[:450] + input_texts.append(random_text) + rambar.n = psutil.virtual_memory().percent + cpubar.n = psutil.cpu_percent() + rambar.refresh() + cpubar.refresh() + + filter_output = input_texts[:512] + + del input_texts + + # print(len(filter_output)) + + output = kv_embedding.embedding(filter_output) diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index b24715862..3e444e333 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -66,6 +66,11 @@ fancy-regex = { version = "0.14", optional = true} getrandom = { version = "0.3" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" +ahash = { version = "0.8.11", features = ["serde"] } +dary_heap = { version = "0.3.6", features = ["serde"] } +compact_str = { version = "0.9", features = ["serde"] } +fnv = "1.0.7" +aneubeck-daachorse = "1.1.1" [features] default = ["progressbar", "onig", "esaxx_fast"] diff --git a/tokenizers/src/models/bpe/backtrack.rs b/tokenizers/src/models/bpe/backtrack.rs new file mode 100644 index 000000000..f419b6506 --- /dev/null +++ b/tokenizers/src/models/bpe/backtrack.rs @@ -0,0 +1,716 @@ +use crate::decoders::byte_level::CHAR_BYTES; +use crate::models::bpe::Pair; +use crate::pre_tokenizers::byte_level::ByteLevel; +use crate::pre_tokenizers::byte_level::BYTES_CHAR; +use crate::tokenizer::{Decoder, Result}; +use ahash::AHashMap; +use aneubeck_daachorse::DoubleArrayAhoCorasick; +use aneubeck_daachorse::DoubleArrayAhoCorasickBuilder; +use fnv::{FnvHashMap, FnvHasher}; +use itertools::Itertools; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::hash::{Hash, Hasher}; +use std::ops::Range; + +use super::MergeMap; +use super::Merges; +use super::Vocab; +use super::VocabR; + +/// Small helper to manage a bit field which supports predecessor and successor queries with a simple scan implementation. +/// This is sufficient for our use case, since two one bits will be at most 128 bits apart. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct BitField { + bitfield: Vec, +} + +impl BitField { + /// All bits are initialized to 1. 
+    pub(crate) fn new(bits: usize) -> Self {
+        Self {
+            bitfield: vec![u64::MAX; (bits + 63) / 64],
+        }
+    }
+
+    pub(crate) fn is_set(&self, bit: usize) -> bool {
+        let (word, bit) = (bit / 64, bit % 64);
+        self.bitfield[word] & (1 << bit) != 0
+    }
+
+    pub(crate) fn clear(&mut self, bit: usize) {
+        let (word, bit) = (bit / 64, bit % 64);
+        self.bitfield[word] &= !(1 << bit);
+    }
+
+    pub(crate) fn successor(&self, bit: usize) -> usize {
+        let (mut word_idx, bit_idx) = (bit / 64, bit % 64);
+        let word = self.bitfield[word_idx] >> bit_idx;
+        if word != 0 {
+            word.trailing_zeros() as usize + bit
+        } else {
+            loop {
+                word_idx += 1;
+                let word = self.bitfield[word_idx];
+                if word != 0 {
+                    break word.trailing_zeros() as usize + word_idx * 64;
+                }
+            }
+        }
+    }
+
+    pub(crate) fn predecessor(&self, bit: usize) -> usize {
+        let (mut word_idx, bit_idx) = (bit / 64, bit % 64);
+        let word = self.bitfield[word_idx] << (63 - bit_idx);
+        if word != 0 {
+            bit - word.leading_zeros() as usize
+        } else {
+            loop {
+                word_idx -= 1;
+                let word = self.bitfield[word_idx];
+                if word != 0 {
+                    break word_idx * 64 + 63 - word.leading_zeros() as usize;
+                }
+            }
+        }
+    }
+}
+
+/// This can be thought of as a lazy variation of the dynamic programming approach.
+/// It only computes those states which have to be visited in order to compute the tokenization
+/// for a given input text.
+/// It keeps track of visited states in a bitfield and only remembers the tokenization
+/// of the currently processed dynamic programming state.
+///
+/// The biggest downside of this approach is that the search for the longest leftmost match
+/// (the first token) has to be reset at every (backtracking) step, which is still a net win
+/// in practice compared to other approaches.
+#[derive(Clone, PartialEq)]
+pub struct BacktrackState<'a> {
+    pub(crate) text: &'a [u8],
+    pub(crate) tokens: Vec<u32>,        // roughly text.len() / 3 tokens are expected
+    pub(crate) next_token: Option<u32>, // longest leftmost match at `pos`, i.e. the first match of bpe.next_match(text)
+    pub(crate) pos: usize,              // current position in the text
+    pub(crate) bitfield: BitField,      // tracks the still-valid tokenization positions, keeping the runtime linear in the input length
+}
+
+impl<'a> BacktrackState<'a> {
+    pub(crate) fn new(text: &'a [u8], next_token: Option<u32>) -> Self {
+        Self::with_capacity(text, next_token, text.len() / 3)
+    }
+
+    pub(crate) fn with_capacity(text: &'a [u8], next_token: Option<u32>, cap: usize) -> Self {
+        Self {
+            text,
+            tokens: Vec::with_capacity(cap),
+            next_token,
+            pos: 0,
+            bitfield: BitField::new(text.len() + 1),
+        }
+    }
+
+    pub(crate) fn count(&self) -> usize {
+        self.tokens.len()
+    }
+
+    pub(crate) fn pos(&self) -> usize {
+        self.pos
+    }
+
+    pub(crate) fn last_token(&self) -> Option<u32> {
+        self.tokens.last().copied()
+    }
+
+    pub(crate) fn into_tokens(self) -> Vec<u32> {
+        self.tokens
+    }
+}
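The `BitField` only needs these simple scan-based queries because, as noted above, two set bits are at most 128 bits apart during encoding. As an editorial illustration (not part of this patch), the intended contract of the three queries is sketched below; the test name and values are made up, and since `BitField` is crate-private the sketch would have to live inside this module:

```rust
// Hypothetical in-module test sketch; BitField is crate-private in this patch.
#[test]
fn bitfield_scan_queries() {
    let mut bf = BitField::new(16); // all bits start out set
    bf.clear(5);
    bf.clear(6);
    assert!(bf.is_set(4));
    // successor(i): smallest set bit index >= i
    assert_eq!(bf.successor(5), 7);
    // predecessor(i): largest set bit index <= i
    assert_eq!(bf.predecessor(6), 4);
}
```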
+#[derive(PartialEq, Clone)]
+pub struct Backtrack {
+    /// All the decoded tokens concatenated into a single byte buffer;
+    /// used to build the Aho-Corasick searchers.
+    all_tokens: Vec<u8>,
+    /// Start index of each token in all_tokens.
+    /// The end is simply the next entry in this vector.
+    token_starts: Vec<u32>,
+    /// Mapping from hash of token to token id.
+    bytes_hash_to_token: FnvHashMap<u32, u32>,
+    /// The two tokens from which the token got merged.
+    /// If the token is an original one, the two entries point back to the token itself.
+    split_table: Vec<(u32, u32)>,
+    /// Mapping from a pair of tokens to a merged token if such a merged token exists.
+    pair_lookup: FnvHashMap<(u32, u32), u32>,
+    /// An Aho-Corasick automaton to find the next longest token in a byte sequence.
+    // #[serde(
+    //     serialize_with = "serialize_daac",
+    //     deserialize_with = "deserialize_daac"
+    // )]
+    longest_searcher: DoubleArrayAhoCorasick<u32>,
+    /// An Aho-Corasick automaton to find ALL tokens in a byte sequence.
+    // #[serde(
+    //     serialize_with = "serialize_daac",
+    //     deserialize_with = "deserialize_daac"
+    // )]
+    pub(crate) overlapping_searcher: DoubleArrayAhoCorasick<u32>,
+    /// An Aho-Corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order.
+    // #[serde(
+    //     serialize_with = "serialize_daac",
+    //     deserialize_with = "deserialize_daac"
+    // )]
+    pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick<u32>,
+    /// Mapping from a token to the next longest prefix token.
+    /// This is in principle information represented by the Aho-Corasick automaton,
+    /// but we don't have efficient access to it and therefore store it here again.
+    /// If there is none, then the value is set to u32::MAX.
+    next_prefix_match: Vec<u32>,
+    /// Hash factor used to prevent hash collisions.
+    hash_factor: u64,
+    vocab: Vocab,
+    vocab_r: VocabR,
+    unk_token: Option<String>,
+    merges: MergeMap,
+}
+
+fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
+    let mut hasher = FnvHasher::default();
+    bytes.hash(&mut hasher);
+    // Note: we save 1/3 of the space for the hashmap by only using the most significant bits of the hash.
+    // To make them unique for the given tokens, we unfortunately have to add another multiplication.
+    ((hasher.finish().wrapping_mul(factor)) >> 32) as u32
+}
+
+// #[cfg(feature = "rand")]
+pub fn find_hash_factor_for_dictionary(tokens: impl IntoIterator<Item = Vec<u8>>) -> u64 {
+    use std::collections::HashSet;
+
+    use rand::Rng;
+
+    let all_tokens: Vec<Vec<u8>> = tokens.into_iter().collect();
+    let mut rnd = rand::rng();
+    loop {
+        let factor: u64 = rnd.random();
+        let mut seen = HashSet::new();
+        if all_tokens
+            .iter()
+            .all(|token| seen.insert(hash_bytes(token, factor)))
+        {
+            return factor;
+        }
+    }
+}
+
+impl Backtrack {
+    pub(crate) fn new(vocab: Vocab, merge_map: MergeMap) -> Self {
+        // let vocab_vec: Vec<_> = vocab
+        //     .into_iter()
+        //     .sorted_unstable_by(|a, b| a.1.cmp(&b.1))
+        //     .map(|(k, _v)| k.chars().map(|b| CHAR_BYTES[&b] as u8).collect::<Vec<u8>>())
+        //     .collect();
+        let mut merges: Vec<_> = merge_map.values().collect();
+        merges.sort();
+        let merge_vocab: Vec<u32> = merges
+            .into_iter()
+            .map(|(_rank, token_id)| *token_id)
+            .collect();
+
+        let vocab_r: AHashMap<_, _> = vocab.iter().map(|(k, v)| (v, k)).collect();
+        let mut tokens: Vec<_> = vocab
+            .clone()
+            .into_iter()
+            .flat_map(|(k, token_id)| {
+                if merge_vocab.contains(&token_id) {
+                    Some((token_id, k))
+                } else {
+                    None
+                }
+            })
+            .collect();
+        tokens.sort();
+        let mut tokens: Vec<_> = tokens.into_iter().map(|(_token_id, k)| k).collect();
+
+        let merge_vocab: Vec<String> = merge_vocab
+            .into_iter()
+            .map(|token_id| vocab_r[&token_id].clone())
+            .collect();
+        tokens.extend(merge_vocab);
+        let vocab_vec: Vec<_> = tokens.into_iter().map(|k| k.as_bytes().to_vec()).collect();
+
+        let hash_factor = find_hash_factor_for_dictionary(vocab_vec.clone());
+        let mut all_tokens = Vec::new();
+        let mut all_tokens_rev = Vec::new();
+        let mut token_starts = vec![0]; // The starting byte index of each token in all_tokens.
+ let mut bytes_hash_to_token = FnvHashMap::default(); + let tokens = vocab_vec; + for (i, token) in tokens.into_iter().enumerate() { + info!( + "token byte: {:?}, {i}", + ByteLevel::default() + .decode_chain(unsafe { vec![String::from_utf8_unchecked(token.clone())] }) + .unwrap() + ); + bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); + all_tokens_rev.extend(token.iter().copied().rev()); + all_tokens.extend(token); + token_starts.push(all_tokens.len() as u32); + } + assert_eq!( + bytes_hash_to_token.len() + 1, + token_starts.len(), + "Some tokens are not unique under the hash function!" + ); // TODO maybe this check is needed? + let longest_searcher = DoubleArrayAhoCorasickBuilder::new() + .match_kind(aneubeck_daachorse::MatchKind::LeftmostLongest) + .build(token_iter(&all_tokens, &token_starts)) + .expect("failed to build AhoCorasick"); + + let overlapping_searcher = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens, &token_starts)).expect(""); + let overlapping_searcher_rev = + DoubleArrayAhoCorasick::::new(token_iter(&all_tokens_rev, &token_starts)) + .expect(""); + + let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts) + .map(|token| { + next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX) + }) + .collect(); + + let vocab: AHashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, bytes)| { + ( + bytes.iter().map(|b| BYTES_CHAR[b]).collect::(), + id as u32, + ) + }) + .collect(); + + let vocab_r: AHashMap = token_iter(&all_tokens, &token_starts) + .enumerate() + .map(|(id, bytes)| { + ( + id as u32, + bytes.iter().map(|b| BYTES_CHAR[b]).collect::(), + ) + }) + .collect(); + + let mut split_table = vec![]; + let mut pair_lookup = FnvHashMap::default(); + let mut merge_map = AHashMap::new(); + + // // First option, use the input merge table. + // if let Some(ref merges) = merges { + // for (index, pair) in merges.into_iter().enumerate() { + // let token1 = &pair.0.clone(); + // let token2 = &pair.1.clone(); + // // TODO something is weird here + // if token1.len() ==1{ + // split_table.push((vocab[token1], vocab[token1])); + // } + // if token2.len() == 1 { + // split_table.push((vocab[token2], vocab[token2])); + // } + // let id1 = vocab[token1]; + // let id2 = vocab[token2]; + // let new_token = format!("{}{}", token1, &token2); + // let new_id = vocab + // .get(&new_token) + // .ok_or(Error::MergeTokenOutOfVocabulary(new_token)); + // if let Ok(id) = new_id { + // pair_lookup.insert((id1, id2), *id); + // split_table.push((id1, id2)); + // merge_map.insert(Pair::from((id1, id2)), (index as u32, *id)); + // } else { + // println!("Token not added?"); + // } + + // // TODO wrong + // } + // split_table.push((merges.len() as u32, merges.len() as u32)); + // } + // Second option, reverse engineer the merge/split table from the vocabulary. 
+ { + for (id, token) in token_iter(&all_tokens, &token_starts).enumerate() { + let mut id1 = next_prefix_match[id]; + while id1 != u32::MAX { + let rest = &token[token_range(&token_starts, id1).len()..]; + if let Some(id2) = find_token_by_bytes( + &all_tokens, + &token_starts, + &bytes_hash_to_token, + rest, + hash_factor, + ) { + if id1 < id as u32 + && id2 < id as u32 + && is_valid_token_pair(&pair_lookup, &split_table, id1, id2) + { + pair_lookup.insert((id1, id2), id as u32); + split_table.push((id1, id2)); + merge_map.insert(Pair::from((id1, id2)), (id as u32, id as u32)); + break; + } + } + id1 = next_prefix_match[id1 as usize]; + } + if id1 == u32::MAX { + split_table.push((id as u32, id as u32)); + } + } + }; + let bpe = Self { + all_tokens, + token_starts, + bytes_hash_to_token, + overlapping_searcher, + overlapping_searcher_rev, + longest_searcher, + next_prefix_match, + pair_lookup, + split_table, + hash_factor, + unk_token: None, + vocab, + vocab_r, + merges: merge_map, + }; + // A health checkup + for token_id in 0..bpe.num_tokens() as u32 { + let bytes = bpe.token_bytes(token_id); + let strs = bytes.iter().map(|b| char::from(*b)).collect::>(); + // println!("Encoding {bytes:?} into bitfield"); + let tokens = bpe.encode_via_bitfield(bytes); + assert_eq!( + tokens, + vec![token_id], + "token {token_id} with bytes {bytes:?} (tokens {strs:?} encodes to {tokens:?} instead of to itself" + ); + } + bpe + } + + fn bitfield_into_tokens(&self, bytes: &[u8], bitfield: BitField, count: usize) -> Vec { + let mut encoded = Vec::with_capacity(count); + let mut start = 0; + while start < bytes.len() { + let end = bitfield.successor(start + 1); + // println!("bitfield's successor {:?}", &bytes[start..end]); + let token = self + .find_token_by_bytes(&bytes[start..end]) + .expect(&format!( + "Could not convert bytes to tokens for bytes: [{:?}]", + bytes.into_iter().map(|b| BYTES_CHAR[b]).join("") + )); + encoded.push(token); + start = end; + } + encoded + } + + fn encode_into_bitfield(&self, bytes: &[u8]) -> (BitField, usize) { + // Reserve for every byte a bit in the bitfield. + let mut bitfield = BitField::new(bytes.len() + 1); + let mut heap = BinaryHeap::with_capacity(bytes.len() * 2); + heap.extend((0..bytes.len().saturating_sub(1)).filter_map(|i| { + self.find_token_by_bytes(&bytes[i..i + 2]) + .map(|e| Reverse((e, i as u32))) + })); + let mut count = bytes.len(); + while let Some(Reverse((token, start))) = heap.pop() { + let start = start as usize; + if !bitfield.is_set(start) { + continue; + } + let mid = bitfield.successor(start + 1); + if mid >= bytes.len() { + continue; + } + let end = bitfield.successor(mid + 1); + if self.token_len(token) != end - start { + continue; + } + bitfield.clear(mid); + count -= 1; + if end < bytes.len() { + let new_end = bitfield.successor(end + 1); + if let Some(e) = self.find_token_by_bytes(&bytes[start..new_end]) { + heap.push(Reverse((e, start as u32))); + } + } + if start > 0 { + let new_start = bitfield.predecessor(start - 1); + if let Some(e) = self.find_token_by_bytes(&bytes[new_start..end]) { + heap.push(Reverse((e, new_start as u32))); + } + } + } + (bitfield, count) + } + + pub fn encode_via_bitfield(&self, text: &[u8]) -> Vec { + let (bitfield, count) = self.encode_into_bitfield(text); + self.bitfield_into_tokens(text, bitfield, count) + } + + /// Return the number of tokens in this BPE dictionary. + pub fn num_tokens(&self) -> usize { + self.token_starts.len() - 1 + } + + /// Converts a token id into its corresponding token bytes. 
+ /// Panics if the token_id is not within the valid 0..num_tokens() range! + pub fn token_bytes(&self, token_id: u32) -> &[u8] { + token_bytes(&self.all_tokens, &self.token_starts, token_id) + } + + pub(crate) fn is_valid_token_pair(&self, token1: u32, token2: u32) -> bool { + is_valid_token_pair(&self.pair_lookup, &self.split_table, token1, token2) + } + + /// Returns the length of the decoded byte slice of a token. + pub fn token_len(&self, token_id: u32) -> usize { + token_range(&self.token_starts, token_id).len() + } + + /// Returns the first longest match in the provided text. + pub(crate) fn next_match(&self, text: &[u8]) -> Option { + next_match(&self.longest_searcher, text) + } + + /// Returns the next token which shares the longest prefix with the specified token. + pub(crate) fn next_prefix(&self, token_id: u32) -> Option { + let prefix = self.next_prefix_match[token_id as usize]; + if prefix == u32::MAX { + None + } else { + Some(prefix) + } + } + + fn find_token_by_bytes(&self, bytes: &[u8]) -> Option { + find_token_by_bytes( + &self.all_tokens, + &self.token_starts, + &self.bytes_hash_to_token, + bytes, + self.hash_factor, + ) + } + + /// Decode a sequence of tokens back to its original byte sequence. + /// Note: we don't return here a str, since not every token sequence corresponds to a valid + /// utf8 sequence. + pub fn decode_tokens(&self, tokens: &[u32]) -> Vec { + let mut text = vec![]; + for token in tokens { + text.extend(self.token_bytes(*token)); + } + text + } + + /// Computes for every prefix of the input text a corresponding last token. + pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec { + let mut last_token = Vec::with_capacity(text.len()); + let mut state = self.overlapping_searcher.start_state(); + for (pos, c) in text.iter().enumerate() { + let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c); + state = s; + for m in iter { + let new_token = m.value(); + let new_range = m.start()..m.end(); + assert_eq!(new_range.end, last_token.len() + 1); + if new_range.start == 0 { + last_token.push(new_token); + break; + } else { + let prev_token = unsafe { *last_token.get_unchecked(new_range.start - 1) }; + if self.is_valid_token_pair(prev_token, new_token) { + last_token.push(new_token); + break; + } + // println!("Finished encoding prefix") + } + } + } + last_token + } + + /// Counts the number tokens produced when encoding the text. 
+ pub fn count(&mut self, text: &[u8]) -> usize { + let mut enc = BacktrackState::new(text, None); + while self.step(&mut enc).is_some() {} + enc.count() + } + + pub fn encode_via_table(&self, text: &[u8]) -> Vec { + let last_token = self.encode_all_prefixes(text); + let mut encoded = Vec::with_capacity(text.len() / 3); + let mut pos = text.len(); + while pos > 0 { + let token = last_token[pos - 1]; + encoded.push(token); + pos -= self.token_len(token); + } + encoded.reverse(); + encoded + } + + pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec { + let next_token = self.next_match(text); + let mut enc = BacktrackState::new(text, next_token); + while self.step(&mut enc).is_some() {} + enc.into_tokens() + } + + pub fn get_vocab(&self) -> Vocab { + self.vocab.clone() + } + + pub fn get_unk_token(&self) -> &Option { + &self.unk_token + } + + pub fn step(&self, backtrack_state: &mut BacktrackState) -> Option { + let mut token = backtrack_state.next_token?; + let last = backtrack_state.tokens.last().copied(); + loop { + // println!("in step, token: {last:?}, {token}"); + let token_len = self.token_len(token); + let end_pos = backtrack_state.pos + token_len; + if backtrack_state.bitfield.is_set(end_pos) + && last + .map(|last_token| self.is_valid_token_pair(last_token, token)) + .unwrap_or(true) + { + backtrack_state.tokens.push(token); + backtrack_state.pos = end_pos; + // In principle, we could in some cases reuse the leftmost longest match iterator. + // Especially when it has to look ahead, this could save scanning the input multiple times. + // But on average this seems to be slower due to the overhead of storing the iterator as part of the struct. + backtrack_state.next_token = self.next_match(&backtrack_state.text[end_pos..]); + break; + } else if let Some(shorter) = self.next_prefix(token) { + token = shorter; + } else { + // Clearing the bitfield when we pop tokens saves a little bit of work... + backtrack_state.bitfield.clear(backtrack_state.pos); + backtrack_state.tokens.pop(); + backtrack_state.pos -= last.map(|t| self.token_len(t)).unwrap_or(0); + backtrack_state.next_token = last; + break; + } + } + // println!("finished step, token: {last:?}, {token}"); + + backtrack_state.next_token + } +} + +// A helper function to iterate over the tokens in a byte sequence +fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterator { + token_starts + .iter() + .tuple_windows() + .map(move |(start, end)| &all_tokens[*start as usize..*end as usize]) +} + +fn next_match(longest_searcher: &DoubleArrayAhoCorasick, text: &[u8]) -> Option { + longest_searcher + .leftmost_find_iter(text) + .map(|m| m.value()) + .next() +} + +fn is_valid_token_pair( + pair_lookup: &FnvHashMap<(u32, u32), u32>, + split_table: &[(u32, u32)], + mut token1: u32, + mut token2: u32, +) -> bool { + // Keep track of the maximum token which can still be chosen across the split point. + let mut limit = u32::MAX; + // println!("checking if {token1}, {token2} is a valid token_pair"); + loop { + // Check whether BPE would choose a different token pair across the split point. + // this is super super important + if let Some(combined) = pair_lookup.get(&(token1, token2)) { + if *combined < limit { + // println!("Done1"); + return false; + } + } + // Reverse the merge operation from BPE. 
+ + // println!("{:?}", split_table); + if token1 > token2 { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + // println!("Done2"); + return true; + } + } + } else { + limit = token2 + 1; + token2 = unsafe { split_table.get_unchecked(token2 as usize).0 }; + if token2 + 1 == limit { + limit = token1; + token1 = unsafe { split_table.get_unchecked(token1 as usize).1 }; + if token1 == limit { + // println!("Done3"); + return true; + } + } + } + } +} + +fn token_range(token_starts: &[u32], token_id: u32) -> Range { + unsafe { + *token_starts.get_unchecked(token_id as usize) as usize + ..*token_starts.get_unchecked(token_id as usize + 1) as usize + } +} + +fn token_bytes<'a>(all_tokens: &'a [u8], token_starts: &[u32], token_id: u32) -> &'a [u8] { + &all_tokens[token_range(token_starts, token_id)] +} + +fn find_token_by_bytes( + all_tokens: &[u8], + token_starts: &[u32], + bytes_hash_to_token: &FnvHashMap, + bytes: &[u8], + hash_factor: u64, +) -> Option { + let hash = hash_bytes(bytes, hash_factor); + let token = *bytes_hash_to_token.get(&hash)?; + if token_bytes(all_tokens, token_starts, token) == bytes { + Some(token) + } else { + None + } +} + +/// Converts the merges strings (for example from `merges.txt` file) with the format +/// "{pair_a} {pair_b}" into the format expected by the BacktrackingBpe struct +pub(crate) fn convert_merges_to_hashmap>( + iter: I, + _vocab: &Vocab, +) -> Result { + let mut merges = vec![]; + + let lines = iter.filter(|l| !l.starts_with("#version")); + for (rank, line) in lines.enumerate() { + let parts = line.split(' ').collect::>(); + if parts.len() != 2 { + return Err(super::Error::BadMerges(rank + 1).into()); + } + + merges.push((parts[0].to_string(), parts[1].to_string())); + } + + Ok(merges) +} diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df..a176fe365 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -1,6 +1,7 @@ //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. use std::{iter, mem}; +mod backtrack; mod model; mod serialization; pub mod trainer; diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 50c9815e9..d0e1cf842 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,20 +1,23 @@ +use super::backtrack::Backtrack; use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; +use ahash::AHashMap; use serde_json::Value; use std::borrow::Cow; + +use std::collections::HashMap; use std::{ - collections::HashMap, fs::File, io::prelude::*, io::{BufRead, BufReader}, path::{Path, PathBuf}, }; -pub type Vocab = HashMap; -type VocabR = HashMap; -pub type MergeMap = HashMap; +pub type Vocab = AHashMap; +pub type VocabR = AHashMap; +pub type MergeMap = AHashMap; pub type Merges = Vec<(String, String)>; struct Config { @@ -41,7 +44,7 @@ impl Default for BpeBuilder { Self { config: Config { files: None, - vocab: HashMap::new(), + vocab: AHashMap::new(), merges: vec![], cache_capacity: DEFAULT_CACHE_CAPACITY, dropout: None, @@ -71,8 +74,41 @@ impl BpeBuilder { /// Set the vocab (token -> ID) and merges mappings. 
#[must_use] - pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self { - self.config.vocab = vocab; + pub fn vocab_and_merges>>( + mut self, + vocab: V, + merges: Merges, + ) -> Self { + self.config.vocab = vocab.into(); + // for (i, (left, right)) in merges.iter().enumerate() { + // // println!("{left:?} - {right:?}"); + // let mut result = left.clone(); + // result.push_str(right); + // if result == String::from("Ġvi") { + // println!("Original merge {left:?} - {right:?} - {i}"); + // // panic!("Stop"); + // } + // if result == String::from("á»") { + // println!("Original merge {left:?} - {right:?}"); + // // panic!("Stop"); + // } + // if result == String::from("á»ĩ") { + // println!("Original merge {left:?} - {right:?}"); + // // panic!("Stop"); + // } + // if result == String::from("iá»ĩ") { + // println!("Original merge {left:?} - {right:?}"); + // // panic!("Stop"); + // } + // if result == String::from("iá»ĩc") { + // println!("Original merge {left:?} - {right:?}"); + // // panic!("Stop"); + // } + // if result == String::from("Ġviá»ĩc") { + // println!("Original merge {left:?} - {right:?}"); + // // panic!("Stop"); + // } + // } self.config.merges = merges; self } @@ -199,6 +235,7 @@ impl BpeBuilder { fuse_unk: self.config.fuse_unk, byte_fallback: self.config.byte_fallback, ignore_merges: self.config.ignore_merges, + backtrack: None, }) } } @@ -230,6 +267,8 @@ pub struct BPE { pub byte_fallback: bool, /// Whether or not to direct output words if they are part of the vocab. pub ignore_merges: bool, + + backtrack: Option, } impl std::fmt::Debug for BPE { @@ -271,6 +310,7 @@ impl Clone for BPE { fuse_unk: self.fuse_unk, byte_fallback: self.byte_fallback, ignore_merges: self.ignore_merges, + backtrack: None, } } } @@ -324,7 +364,7 @@ impl BPE { let mut buffer = String::new(); vocab_file.read_to_string(&mut buffer)?; let json: Value = serde_json::from_str(&buffer)?; - let mut vocab = HashMap::new(); + let mut vocab = AHashMap::new(); match json { Value::Object(m) => { for (token, id) in m { @@ -361,8 +401,16 @@ impl BPE { } } - pub fn get_vocab(&self) -> Vocab { - self.vocab.clone() + pub fn get_vocab(&self) -> HashMap { + self.vocab.clone().into_iter().collect() + } + + pub fn get_vocab_r(&self) -> HashMap { + self.vocab_r.clone().into_iter().collect() + } + + pub fn get_merges(&self) -> &AHashMap { + &self.merges } pub fn get_unk_token(&self) -> &Option { @@ -455,7 +503,10 @@ impl BPE { word.add(unk_id, unk_len); } + // println!("Word {word:?}"); + word.merge_all(&self.merges, self.dropout); + // println!("After Word {word:?}"); Ok(word) } @@ -488,13 +539,17 @@ impl BPE { } Ok(ret) } + + pub fn enable_backtrack(&mut self) { + self.backtrack = Some(Backtrack::new(self.vocab.clone(), self.merges.clone())); + } } impl Model for BPE { type Trainer = BpeTrainer; fn get_vocab(&self) -> HashMap { - self.vocab.clone() + self.vocab.clone().into_iter().collect() } fn get_vocab_size(&self) -> usize { @@ -502,10 +557,24 @@ impl Model for BPE { } fn tokenize(&self, sequence: &str) -> Result> { + // println!("Tokenizing {sequence}"); if sequence.is_empty() { return Ok(vec![]); } + if let Some(backtrack) = &self.backtrack { + let ids = backtrack.encode_via_backtracking(sequence.as_bytes()); + let tokens = ids + .into_iter() + .map(|id| Token { + id, + value: self.vocab_r[&id].clone(), + offsets: (0, 0), + }) + .collect(); + return Ok(tokens); + } + if self.dropout.is_none() || self.dropout == Some(0.0) { self.tokenize_with_cache(sequence) } else { diff --git 
a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cc15102..98cf54944 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -1,10 +1,10 @@ use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE}; +use ahash::AHashMap; use serde::{ de::{Error, MapAccess, Visitor}, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer, }; -use std::collections::HashMap; impl Serialize for BPE { fn serialize(&self, serializer: S) -> Result @@ -80,7 +80,7 @@ impl<'de> Visitor<'de> for BPEVisitor { V: MapAccess<'de>, { let mut builder = BpeBuilder::new(); - let mut vocab: Option> = None; + let mut vocab: Option> = None; #[derive(Debug, Deserialize)] #[serde(untagged)] diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 2484865be..50cc52099 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -4,15 +4,17 @@ use super::{Pair, WithFirstLastIterator, Word, BPE}; use crate::parallelism::*; use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::utils::progress::{ProgressBar, ProgressStyle}; +use ahash::{AHashMap, AHashSet}; +use compact_str::CompactString; +use dary_heap::OctonaryHeap; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use std::collections::{BinaryHeap, HashMap, HashSet}; #[derive(Debug, Eq)] struct Merge { pair: Pair, count: u64, - pos: HashSet, + pos: AHashSet, } impl PartialEq for Merge { fn eq(&self, other: &Self) -> bool { @@ -41,7 +43,7 @@ struct Config { show_progress: bool, special_tokens: Vec, limit_alphabet: Option, - initial_alphabet: HashSet, + initial_alphabet: AHashSet, continuing_subword_prefix: Option, end_of_word_suffix: Option, max_token_length: Option, @@ -62,7 +64,7 @@ impl Default for BpeTrainerBuilder { show_progress: true, special_tokens: vec![], limit_alphabet: None, - initial_alphabet: HashSet::new(), + initial_alphabet: AHashSet::new(), continuing_subword_prefix: None, end_of_word_suffix: None, max_token_length: None, @@ -114,7 +116,7 @@ impl BpeTrainerBuilder { /// Set the initial alphabet #[must_use] - pub fn initial_alphabet(mut self, alphabet: HashSet) -> Self { + pub fn initial_alphabet(mut self, alphabet: AHashSet) -> Self { self.config.initial_alphabet = alphabet; self } @@ -151,7 +153,7 @@ impl BpeTrainerBuilder { continuing_subword_prefix: self.config.continuing_subword_prefix, end_of_word_suffix: self.config.end_of_word_suffix, max_token_length: self.config.max_token_length, - words: HashMap::new(), + words: AHashMap::new(), } } } @@ -187,7 +189,7 @@ pub struct BpeTrainer { pub limit_alphabet: Option, /// The initial alphabet we want absolutely to include. 
This allows to cover /// some characters that are not necessarily in the training set - pub initial_alphabet: HashSet, + pub initial_alphabet: AHashSet, /// An optional prefix to use on any subword that exist only behind another one pub continuing_subword_prefix: Option, /// An optional suffix to characterize and end-of-word subword @@ -195,7 +197,7 @@ pub struct BpeTrainer { /// An optional parameter to limit the max length of any single token pub max_token_length: Option, - words: HashMap, + words: AHashMap, } impl Default for BpeTrainer { @@ -251,11 +253,16 @@ impl BpeTrainer { } /// Add the provided special tokens to the initial vocabulary - fn add_special_tokens(&self, w2id: &mut HashMap, id2w: &mut Vec) { + fn add_special_tokens( + &self, + w2id: &mut AHashMap, + id2w: &mut Vec, + ) { for token in &self.special_tokens { - if !w2id.contains_key(&token.content) { - id2w.push(token.content.to_owned()); - w2id.insert(token.content.to_owned(), (id2w.len() - 1) as u32); + // get hash of content + if !w2id.contains_key(&CompactString::from(&token.content)) { + id2w.push(CompactString::from(&token.content)); + w2id.insert(CompactString::from(&token.content), (id2w.len() - 1) as u32); } } } @@ -263,12 +270,12 @@ impl BpeTrainer { /// Compute the initial alphabet and limit it if relevant fn compute_alphabet( &self, - wc: &HashMap, - w2id: &mut HashMap, - id2w: &mut Vec, + wc: &AHashMap, + w2id: &mut AHashMap, + id2w: &mut Vec, ) { // Compute the alphabet from seen words - let mut alphabet: HashMap = HashMap::new(); + let mut alphabet: AHashMap = AHashMap::new(); for (word, count) in wc { for c in word.chars() { alphabet @@ -312,19 +319,26 @@ impl BpeTrainer { kept.sort_unstable_by_key(|k| (*k.0) as u32); kept.into_iter().for_each(|(c, _)| { let s = c.to_string(); + /* if !w2id.contains_key(&s) { id2w.push(s.clone()); w2id.insert(s, (id2w.len() - 1) as u32); } + */ + // u64 hash version + if !w2id.contains_key(&CompactString::from(&s)) { + id2w.push(CompactString::from(&s)); + w2id.insert(CompactString::from(&s), (id2w.len() - 1) as u32); + } }); } /// Tokenize words and add subwords to the vocabulary when relevant fn tokenize_words( &self, - wc: &HashMap, - w2id: &mut HashMap, - id2w: &mut Vec, + wc: &AHashMap, + w2id: &mut AHashMap, + id2w: &mut Vec, p: &Option, ) -> (Vec, Vec) { let mut words: Vec = Vec::with_capacity(wc.len()); @@ -336,7 +350,7 @@ impl BpeTrainer { for (is_first, is_last, c) in word.chars().with_first_and_last() { let mut s = c.to_string(); - if w2id.contains_key(&s) { + if w2id.contains_key(&CompactString::from(&s)) { // Found the initial char in the authorized alphabet // Add the `continuing_subword_prefix` if relevant @@ -353,11 +367,11 @@ impl BpeTrainer { } // Insert the new formed string if necessary - if !w2id.contains_key(&s) { - id2w.push(s.clone()); - w2id.insert(s.clone(), (id2w.len() - 1) as u32); + if !w2id.contains_key(&CompactString::from(&s)) { + id2w.push(CompactString::from(&s)); + w2id.insert(CompactString::from(&s), (id2w.len() - 1) as u32); } - current_word.add(w2id[&s], 1); // We do not care about the len here + current_word.add(w2id[&CompactString::from(&s)], 1); // We do not care about the len here } } words.push(current_word); @@ -375,13 +389,13 @@ impl BpeTrainer { words: &[Word], counts: &[u64], p: &Option, - ) -> (HashMap, HashMap>) { + ) -> (AHashMap, AHashMap>) { words .maybe_par_iter() .enumerate() .map(|(i, word)| { - let mut pair_counts = HashMap::new(); - let mut where_to_update: HashMap> = HashMap::new(); + let mut pair_counts = 
AHashMap::new(); + let mut where_to_update: AHashMap> = AHashMap::new(); for window in word.get_chars().windows(2) { let cur_pair: Pair = (window[0], window[1]); @@ -399,7 +413,7 @@ impl BpeTrainer { h.insert(i); }) .or_insert_with(|| { - let mut h = HashSet::new(); + let mut h = AHashSet::new(); h.insert(i); h }); @@ -413,7 +427,7 @@ impl BpeTrainer { (pair_counts, where_to_update) }) .reduce( - || (HashMap::new(), HashMap::new()), + || (AHashMap::new(), AHashMap::new()), |(mut pair_counts, mut where_to_update), (pc, wtu)| { for (k, v) in pc { pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v); @@ -431,11 +445,11 @@ impl BpeTrainer { pub fn do_train( &self, - word_counts: &HashMap, + word_counts: &AHashMap, model: &mut BPE, ) -> Result> { - let mut word_to_id: HashMap = HashMap::with_capacity(self.vocab_size); - let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); + let mut word_to_id: AHashMap = AHashMap::with_capacity(self.vocab_size); + let mut id_to_word: Vec = Vec::with_capacity(self.vocab_size); let max_token_length: usize = self.max_token_length.unwrap_or(usize::MAX); let progress = self.setup_progress(); @@ -464,7 +478,7 @@ impl BpeTrainer { self.update_progress(&progress, words.len(), "Count pairs"); let (mut pair_counts, mut where_to_update) = self.count_pairs(&words, &counts, &progress); // Insert them in the queue - let mut queue = BinaryHeap::with_capacity(pair_counts.len()); + let mut queue = OctonaryHeap::with_capacity(pair_counts.len()); where_to_update.drain().for_each(|(pair, pos)| { let count = pair_counts[&pair]; if count > 0 { @@ -510,7 +524,7 @@ impl BpeTrainer { if let Some(prefix) = &self.continuing_subword_prefix { if part_b.starts_with(prefix) { let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum(); - part_b = part_b[prefix_byte_len..].to_string(); + part_b = CompactString::from(&part_b[prefix_byte_len..]); } } let new_token = format!("{part_a}{part_b}"); @@ -520,19 +534,19 @@ impl BpeTrainer { // Insert new token if it does not already exist let new_token_id = word_to_id - .get(&new_token) + .get(&CompactString::from(&new_token)) .copied() .unwrap_or(id_to_word.len() as u32); - if !word_to_id.contains_key(&new_token) { - id_to_word.push(new_token.clone()); - word_to_id.insert(new_token.clone(), new_token_id); + if !word_to_id.contains_key(&CompactString::from(&new_token)) { + id_to_word.push(CompactString::from(&new_token)); + word_to_id.insert(CompactString::from(&new_token), new_token_id); } merges.push((top.pair, new_token_id)); // Merge the new pair in every words // Safety: This is just a type assertion, the code below may no longer be safe // if the type of `pos` changes - let pos: &HashSet = &top.pos; + let pos: &AHashSet = &top.pos; let words_len = words.len(); struct WordPtr(*mut Word); @@ -544,11 +558,8 @@ impl BpeTrainer { let changes = pos .maybe_par_iter() .flat_map(|&i| { - // Safety: - // We are producing a valid pointer since we are indexing in bounds - // - // We can access each `word` here in parallel because each position - // can be there only once (pos is a HashSet). + // We can merge each of these words in parallel here because each position + // can be there only once (AHashSet). So this is safe. 
unsafe { assert!(i < words_len); // This is words[i], but avoids needing to go through &T (which triggers UB) @@ -577,7 +588,7 @@ impl BpeTrainer { h.insert(iw); }) .or_insert_with(|| { - let mut h = HashSet::new(); + let mut h = AHashSet::new(); h.insert(iw); h }); @@ -601,7 +612,12 @@ impl BpeTrainer { self.finalize_progress(&progress, merges.len()); // Transfer new vocab & options to model - model.vocab = word_to_id; + //model.vocab = word_to_id; + model.vocab = word_to_id + .into_iter() + // we have to look up the string in id_to_word because the key in word_to_id is a hash + .map(|(_key, val)| (id_to_word[val as usize].to_string(), val)) + .collect(); model.vocab_r = model .vocab .iter() @@ -647,18 +663,20 @@ impl Trainer for BpeTrainer { S: AsRef + Send, F: Fn(&str) -> Result> + Sync, { - let words: Result> = iterator + let words: Result> = iterator .maybe_par_bridge() .map(|sequence| { let words = process(sequence.as_ref())?; - let mut map = HashMap::new(); + let mut map = AHashMap::new(); for word in words { - map.entry(word).and_modify(|c| *c += 1).or_insert(1); + map.entry(CompactString::from(word)) + .and_modify(|c| *c += 1) + .or_insert(1); } Ok(map) }) .reduce( - || Ok(HashMap::new()), + || Ok(AHashMap::new()), |acc, ws| { let mut acc = acc?; for (k, v) in ws? { @@ -676,11 +694,12 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { use super::{BpeTrainer, Pair, BPE}; - use std::collections::HashMap; + use ahash::AHashMap; + use compact_str::CompactString; #[test] fn test_train() { - let word_counts: HashMap = [ + let word_counts: AHashMap = [ ("roses".into(), 1), ("are".into(), 2), ("red".into(), 1), @@ -705,7 +724,7 @@ mod tests { // Vocab should contain all of the characters from the `word_counts` mapping // as well as three merges: 're', 'are', and 'is'. - let expected_vocab: HashMap = [ + let expected_vocab: AHashMap = [ ("-".into(), 0), ("2".into(), 1), ("B".into(), 2), @@ -741,7 +760,7 @@ mod tests { // where 'rank' determines the order in which this merge will be applied during // tokenization, and 'id' is the vocab id of the symbol resulting from merging // the pair of symbols in the corresponding key. - let expected_merges: HashMap = [ + let expected_merges: AHashMap = [ ((17, 11), (0, 22)), // 'r' + 'e' -> 're' ((8, 22), (1, 23)), // 'a' + 're' -> 'are' ((13, 18), (2, 24)), // 'i' + 's' -> 'is' @@ -759,7 +778,7 @@ mod tests { */ let max_token_length = 16; - let long_word_counts: HashMap = [ + let long_word_counts: AHashMap = [ ("singlelongtokenwithoutcasechange", 2), ("singleLongTokenWithCamelCaseChange", 2), ("Longsingletokenwithpunctu@t!onwithin", 2), @@ -774,7 +793,7 @@ mod tests { ("GPT-2", 2), ] .iter() - .map(|(key, value)| (key.to_string(), *value)) + .map(|(key, value)| (CompactString::from(key.to_string()), *value)) .collect(); let trainer = BpeTrainer::builder() .max_token_length(Some(max_token_length)) @@ -799,7 +818,7 @@ mod tests { // directly compares tokens with known expected values. // maybe unstable depending on specific settings or changes. 
         */
-        let long_word_counts: HashMap<String, u64> = [
+        let long_word_counts: AHashMap<CompactString, u64> = [
             ("sin", 2),
             ("Sin", 2),
             ("Lon", 2),
@@ -814,7 +833,7 @@ mod tests {
             ("GP", 2),
         ]
         .iter()
-        .map(|(key, value)| (key.to_string(), *value))
+        .map(|(key, value)| (CompactString::from(key.to_string()), *value))
         .collect();
         let trainer = BpeTrainer::builder()
             .max_token_length(Some(2))
@@ -823,8 +842,8 @@ mod tests {
             .build();
         let mut model = BPE::default();
         trainer.do_train(&long_word_counts, &mut model).unwrap();
-        let trained_vocab: HashMap<String, u32> = model.get_vocab();
-        let expected_vocab: HashMap<String, u32> = [
+        let trained_vocab: AHashMap<String, u32> = model.get_vocab().into_iter().collect();
+        let expected_vocab: AHashMap<String, u32> = [
             ("短", 12),
             ("n", 6),
             ("i", 5),
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 9d09fa2af..7bf2dee56 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -1,7 +1,8 @@
 use super::Pair;
+use ahash::AHashMap;
+use dary_heap::QuaternaryHeap;
 use rand::{rng, Rng};
 use std::cmp::Ordering;
-use std::collections::{BinaryHeap, HashMap};
 
 #[derive(Debug, Eq)]
 struct Merge {
@@ -158,8 +159,8 @@ impl Word {
         changes
     }
 
-    pub(super) fn merge_all(&mut self, merges: &HashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
-        let mut queue = BinaryHeap::with_capacity(self.symbols.len());
+    pub(super) fn merge_all(&mut self, merges: &AHashMap<Pair, (u32, u32)>, dropout: Option<f32>) {
+        let mut queue = QuaternaryHeap::with_capacity(self.symbols.len());
         let mut skip = Vec::with_capacity(queue.len());
 
         queue.extend(
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index 4e5419bad..932bc598d 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -5,6 +5,7 @@ pub mod unigram;
 pub mod wordlevel;
 pub mod wordpiece;
 
+use ahash::AHashMap;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
 
@@ -19,11 +20,11 @@ use crate::{AddedToken, Model, Result, Token, Trainer};
 
 /// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
 /// of token ID, smallest to largest.
 struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
+    vocab_r: &'a AHashMap<u32, String>,
 }
 
 impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+    fn new(vocab_r: &'a AHashMap<u32, String>) -> Self {
         Self { vocab_r }
     }
 }
@@ -301,8 +302,8 @@ mod tests {
 
     #[test]
     fn incomplete_ordered_vocab() {
-        let vocab_r: HashMap<u32, String> =
-            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
+        let vocab_r: AHashMap<u32, String> =
+            AHashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
 
         let ordered = OrderedVocabIter::new(&vocab_r);
diff --git a/tokenizers/src/models/unigram/lattice.rs b/tokenizers/src/models/unigram/lattice.rs
index 1302019e1..2897bb376 100644
--- a/tokenizers/src/models/unigram/lattice.rs
+++ b/tokenizers/src/models/unigram/lattice.rs
@@ -1,13 +1,13 @@
+use dary_heap::QuaternaryHeap;
 use rand::distr::weighted::WeightedIndex;
 use rand::{prelude::*, rng};
 use std::cell::RefCell;
 use std::cmp::{min, Ordering};
-use std::collections::BinaryHeap;
 use std::rc::Rc;
 
 type NodeRef = Rc<RefCell<Node>>;
 type HypothesisRef = Rc<RefCell<Hypothesis>>;
-type Agenda = BinaryHeap<Hypothesis>;
+type Agenda = QuaternaryHeap<Hypothesis>;
 
 struct Hypothesis {
     node_ref: NodeRef,
@@ -240,7 +240,7 @@ impl<'a> Lattice<'a> {
             1 => vec![self.viterbi()],
             _ => {
                 // let k_reserved_hypothesis_size = 512;
-                let mut agenda: Agenda = BinaryHeap::new();
+                let mut agenda: Agenda = QuaternaryHeap::new();
                 let mut hypotheses: Vec<Vec<NodeRef>> = vec![];
                 let eos = self.eos_node();
                 let score = eos.borrow().score;
@@ -282,7 +282,7 @@ impl<'a> Lattice<'a> {
                     let k_max_agenda_size = 100_000;
                     let k_min_agenda_size = 512;
                     if agenda.len() > k_max_agenda_size {
-                        let mut new_agenda = BinaryHeap::new();
+                        let mut new_agenda = QuaternaryHeap::new();
                         let len = min(k_min_agenda_size, n * 10);
                         for _i in 0..len {
                             new_agenda.push(agenda.pop().unwrap());
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index fd498c822..7b876ec9d 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -5,13 +5,14 @@ use super::{
 };
 use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::{Cache, MAX_LENGTH};
-
 use std::collections::HashMap;
+
+use ahash::AHashMap;
 use std::convert::TryInto;
 use std::fs::read_to_string;
 use std::path::{Path, PathBuf};
 
-type TokenMap = HashMap<String, u32>;
+type TokenMap = AHashMap<String, u32>;
 type Vocab = Vec<(String, f64)>;
 
 /// A `Unigram` model to encode sentences.
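
The `unigram/model.rs` hunks above and below show the pattern the whole patch relies on: `AHashMap` is an API-compatible stand-in for `std::collections::HashMap` with a faster, non-cryptographic hasher, so internal storage switches to it while public signatures such as `get_vocab` keep returning the std type. Below is a minimal, self-contained sketch of that boundary, illustrative only and not part of the diff; it assumes the `ahash` crate and a made-up `Model` type standing in for `Unigram`/`BPE`.

```rust
use ahash::AHashMap;
use std::collections::HashMap;

// Hypothetical model type, standing in for the patch's `Unigram` / `BPE` structs.
struct Model {
    // Internal storage uses ahash's faster hasher.
    token_to_ids: AHashMap<String, u32>,
}

impl Model {
    // The public accessor still exposes std's HashMap, so downstream callers
    // are unaffected by the internal switch.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.token_to_ids.clone().into_iter().collect()
    }
}

fn main() {
    let model = Model {
        token_to_ids: AHashMap::from([("hello".to_string(), 0), ("world".to_string(), 1)]),
    };
    assert_eq!(model.get_vocab().len(), 2);
}
```

The conversion only happens at the API boundary, so the cost is one pass over the vocabulary per call instead of a slower hash on every internal lookup.
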
@@ -98,7 +99,7 @@ impl Unigram {
         byte_fallback: bool,
     ) -> Result<Self> {
         let n = vocab.len();
-        let mut token_to_ids: TokenMap = HashMap::new();
+        let mut token_to_ids: TokenMap = AHashMap::new();
         let mut builder = TrieBuilder::default();
 
         if let Some(unk_id) = unk_id {
@@ -416,7 +417,7 @@ impl Model for Unigram {
     type Trainer = UnigramTrainer;
 
     fn get_vocab(&self) -> HashMap<String, u32> {
-        self.token_to_ids.clone()
+        self.token_to_ids.clone().into_iter().collect()
     }
 
     fn get_vocab_size(&self) -> usize {
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index d6d2830fd..920dee525 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -2,10 +2,10 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use ahash::{AHashMap, AHashSet};
 use log::debug;
 use serde::{Deserialize, Serialize};
 use std::cmp::Reverse;
-use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
 
 // A token and a score
@@ -57,8 +57,8 @@ pub struct UnigramTrainer {
     pub shrinking_factor: f64,
     #[builder(default = "vec![]")]
     pub special_tokens: Vec<AddedToken>,
-    #[builder(default = "HashSet::new()")]
-    pub initial_alphabet: HashSet<char>,
+    #[builder(default = "AHashSet::new()")]
+    pub initial_alphabet: AHashSet<char>,
     #[builder(default = "None")]
     pub unk_token: Option<String>,
 
@@ -67,8 +67,8 @@ pub struct UnigramTrainer {
     pub max_piece_length: usize,
     #[builder(default = "1_000_000")]
     seed_size: usize,
-    #[builder(default = "HashMap::new()")]
-    words: HashMap<String, u32>,
+    #[builder(default = "AHashMap::new()")]
+    words: AHashMap<String, u32>,
 }
 
 impl Default for UnigramTrainer {
@@ -110,17 +110,17 @@ impl UnigramTrainer {
         true
     }
 
-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
+    fn finalize(&self, model: Unigram, required_chars: AHashSet<String>) -> Result<Unigram> {
         let mut min_score_penalty = 0.0;
         let min_score_penalty_delta = 0.0001;
 
         let mut pieces: Vec<(String, f64)> = vec![];
-        let mut inserted: HashSet<String> = HashSet::new();
+        let mut inserted: AHashSet<String> = AHashSet::new();
 
         // We don't want to include the <UNK> that was used to train
         inserted.insert("<UNK>".into());
 
-        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
+        let existing_pieces: AHashMap<String, f64> = model.iter().cloned().collect();
         for c in required_chars {
             if let Some(t) = existing_pieces.get(&c) {
                 inserted.insert(c.clone());
@@ -185,7 +185,7 @@ impl UnigramTrainer {
         )
     }
 
-    fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
+    fn required_chars(&self, word_counts: &[Sentence]) -> AHashSet<String> {
         word_counts
             .iter()
             .flat_map(|(s, _count)| s.chars())
@@ -205,7 +205,7 @@ impl UnigramTrainer {
             .sum::<usize>()
             + sentences.len();
         let mut flat_string = String::with_capacity(total);
-        let mut all_chars: HashMap<char, u64> = HashMap::new();
+        let mut all_chars: AHashMap<char, u64> = AHashMap::new();
         let c_sentence_boundary = '\0';
         let k_sentence_boundary = '\0'.to_string();
         for (string, n) in sentences {
@@ -631,18 +631,18 @@ impl Trainer for UnigramTrainer {
         S: AsRef<str> + Send,
         F: Fn(&str) -> Result<Vec<String>> + Sync,
     {
-        let words: Result<HashMap<String, u32>> = iterator
+        let words: Result<AHashMap<String, u32>> = iterator
             .maybe_par_bridge()
             .map(|sequence| {
                 let words = process(sequence.as_ref())?;
-                let mut map = HashMap::new();
+                let mut map = AHashMap::new();
                 for word in words {
                     map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                 }
                 Ok(map)
             })
             .reduce(
-                || Ok(HashMap::new()),
+                || Ok(AHashMap::new()),
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
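
Both `feed` implementations above share the same map/reduce shape: each parallel worker builds a local `AHashMap` of word counts, and `reduce` merges the partial maps into one. The sketch below is an illustration of that shape only, not the patch's code: it assumes the `rayon` and `ahash` crates, uses plain `par_bridge` in place of the crate's `maybe_par_bridge` helper (which can fall back to a serial iterator), and skips the `Result` plumbing.

```rust
use ahash::AHashMap;
use rayon::iter::{ParallelBridge, ParallelIterator};

fn count_words<I>(sequences: I) -> AHashMap<String, u32>
where
    I: Iterator<Item = String> + Send,
{
    sequences
        .par_bridge()
        .map(|sequence| {
            // One local map per unit of work; no locking needed.
            let mut map = AHashMap::new();
            for word in sequence.split_whitespace() {
                *map.entry(word.to_string()).or_insert(0u32) += 1;
            }
            map
        })
        // Merge the partial maps pairwise into the final counts.
        .reduce(AHashMap::new, |mut acc, partial| {
            for (word, count) in partial {
                *acc.entry(word).or_insert(0) += count;
            }
            acc
        })
}

fn main() {
    let lines = vec!["the cat sat".to_string(), "the cat".to_string()];
    let counts = count_words(lines.into_iter());
    assert_eq!(counts.get("the"), Some(&2));
    assert_eq!(counts.get("sat"), Some(&1));
}
```
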
@@ -716,7 +716,7 @@ mod tests {
     fn test_initial_alphabet() {
         let trainer = UnigramTrainerBuilder::default()
             .show_progress(false)
-            .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
+            .initial_alphabet(AHashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
             .build()
             .unwrap();
 
@@ -727,7 +727,7 @@ mod tests {
             vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
                 .into_iter()
                 .map(|s| s.to_owned())
-                .collect::<HashSet<String>>()
+                .collect::<AHashSet<String>>()
         );
     }
diff --git a/tokenizers/src/models/unigram/trie.rs b/tokenizers/src/models/unigram/trie.rs
index 2f94b1766..dd06f7f02 100644
--- a/tokenizers/src/models/unigram/trie.rs
+++ b/tokenizers/src/models/unigram/trie.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use ahash::AHashMap;
 use std::hash::Hash;
 
 #[derive(Default)]
@@ -78,14 +78,14 @@ impl