Draft backtrack #1712

Draft: wants to merge 9 commits into main
1 change: 1 addition & 0 deletions bindings/node/Cargo.toml
@@ -14,6 +14,7 @@ napi = "2"
napi-derive = "2"
serde = { version = "1.0.163", features = ["derive"] }
tokenizers = { path = "../../tokenizers/" }
ahash = { version = "0.8.11", features = ["serde"] }

[build-dependencies]
napi-build = "2"
20 changes: 14 additions & 6 deletions bindings/node/src/models.rs
@@ -1,14 +1,15 @@
use crate::arc_rwlock_serde;
use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask};
use crate::trainers::Trainer;
use ahash::AHashMap;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use tokenizers as tk;
use tokenizers::models::bpe::{BpeBuilder, Merges, Vocab};
use tokenizers::models::bpe::{BpeBuilder, Merges};
use tokenizers::models::wordlevel::WordLevelBuilder;
use tokenizers::models::wordpiece::WordPieceBuilder;

@@ -44,8 +45,13 @@ impl Bpe {
}

#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vocab, merges: Merges, options: Option<BpeOptions>) -> Result<Model> {
pub fn init(
vocab: HashMap<String, u32>,
merges: Merges,
options: Option<BpeOptions>,
) -> Result<Model> {
let options = options.unwrap_or_default();
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
let mut builder = tk::models::bpe::BPE::builder().vocab_and_merges(vocab, merges);
builder = options.apply_to_bpe_builder(builder);
let model = builder
@@ -206,10 +212,11 @@ pub struct WordPiece {}
#[napi]
impl WordPiece {
#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vocab, options: Option<WordPieceOptions>) -> Result<Model> {
pub fn init(vocab: HashMap<String, u32>, options: Option<WordPieceOptions>) -> Result<Model> {
let options = options.unwrap_or_default();

let mut builder = tk::models::wordpiece::WordPiece::builder().vocab(vocab);
let mut builder = tk::models::wordpiece::WordPiece::builder()
.vocab(vocab.into_iter().collect::<AHashMap<_, _>>());
builder = options.apply_to_wordpiece_builder(builder);
let model = builder
.build()
@@ -263,9 +270,10 @@ pub struct WordLevel {}
#[napi]
impl WordLevel {
#[napi(factory, ts_return_type = "Model")]
pub fn init(vocab: Vocab, options: Option<WordLevelOptions>) -> Result<Model> {
pub fn init(vocab: HashMap<String, u32>, options: Option<WordLevelOptions>) -> Result<Model> {
let options = options.unwrap_or_default();
let mut builder = tk::models::wordlevel::WordLevel::builder().vocab(vocab);
let mut builder =
tk::models::wordlevel::WordLevel::builder().vocab(vocab.into_iter().collect());
builder = options.apply_to_wordlevel_builder(builder);
let model = builder
.build()
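Note on the pattern above: the napi surface (and, below, the pyo3 one) keeps accepting a plain `std::collections::HashMap<String, u32>`, which is what the FFI conversion layer handles, and each call site then does `vocab.into_iter().collect()` into an `AHashMap` before handing the vocab to the core builders. The new `ahash` dependency in both bindings exists only to name that target type; the core `tokenizers` crate presumably switched its vocab maps to `AHashMap` as part of this PR.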
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
@@ -18,6 +18,7 @@ pyo3 = { version = "0.25", features = ["abi3", "abi3-py39", "py-clone"] }
numpy = "0.25"
ndarray = "0.16"
itertools = "0.14"
ahash = { version = "0.8.11", features = ["serde"] }

[dependencies.tokenizers]
path = "../../tokenizers"
88 changes: 88 additions & 0 deletions bindings/python/benches/test_backtrack.py
@@ -0,0 +1,88 @@
import os
import argparse
import datetime
from datasets import load_dataset
from tokenizers import Tokenizer
from typing import Tuple

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B"
DATASET = "facebook/xnli"
DATASET_CONFIG = "all_languages"
DEFAULT_THREADS = [2**i for i in range(8) if 2**i <= os.cpu_count()]


def format_byte_size(num_bytes: int) -> Tuple[str, str]:
    """Convert bytes to a human-readable format (KB, MB, GB)."""
    num_bytes_f = float(num_bytes)
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if num_bytes_f < 1024:
            return f"{num_bytes_f:.2f} {unit}", unit
        num_bytes_f /= 1024
    return f"{num_bytes_f:.2f} PB", "PB"


def test(model: str, dataset: str, dataset_config: str):
    dataset_xnli = load_dataset(dataset, dataset_config)
    # Same model loaded twice: `tokenizer` keeps the regular BPE encoder,
    # `tokenizer2` is switched to the backtracking encoder added in this PR.
    tokenizer = Tokenizer.from_pretrained(model)
    tokenizer2 = Tokenizer.from_pretrained(model)
    tokenizer2.enable_backtrack()

    for easy in ["1880", " cream"]:
        encoded = tokenizer.encode(easy)
        encoded2 = tokenizer2.encode(easy)
        if encoded.ids != encoded2.ids:
            import ipdb

            ipdb.set_trace()
        assert encoded.ids == encoded2.ids

    sentences = []
    en_sentences = []
    for _i, item in enumerate(dataset_xnli["train"]):
        # sentence = item["premise"]["en"]
        # sentences.append(sentence)
        for lang, sentence in item["premise"].items():
            if lang == "en":
                en_sentences.append(sentence)
            sentences.append(sentence)
    sentences = en_sentences + sentences

    start = datetime.datetime.now()
    encoded = tokenizer.encode_batch_fast(sentences)
    print(f"Took {datetime.datetime.now() - start}")

    start = datetime.datetime.now()
    encoded2 = tokenizer2.encode_batch_fast(sentences)
    print(f"Took {datetime.datetime.now() - start}")

    assert len(encoded) == len(encoded2)
    assert len(encoded) == len(sentences)
    total = 0
    correct = 0
    for enc, enc2, sentence in zip(encoded, encoded2, sentences):
        # if enc.ids != enc2.ids:
        #     print(enc.ids)
        #     print(enc2.ids)
        if enc.ids == enc2.ids:
            correct += 1
        total += 1
        assert enc.ids == enc2.ids, f"{enc.ids} != {enc2.ids} (Source: {sentence})"
    print(f"{correct} / {total} ({correct / total * 100:.2f}%)")
    # print("All good !")


def main():
    parser = argparse.ArgumentParser(
        prog="bench_tokenizer",
        description="Getting a feel for speed when tokenizing",
    )
    parser.add_argument("-m", "--model", default=MODEL_ID, type=str)
    parser.add_argument("-d", "--dataset", default=DATASET, type=str)
    parser.add_argument("-ds", "--dataset-config", default=DATASET_CONFIG, type=str)
    args = parser.parse_args()
    test(args.model, args.dataset, args.dataset_config)


# Call the function to run the benchmark
if __name__ == "__main__":
    main()
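The benchmark above is a plain argparse script: assuming the `datasets` package is installed and the credentials in use can access the Llama 3.1 tokenizer, it can be run as `python bindings/python/benches/test_backtrack.py`, with `-m/--model`, `-d/--dataset` and `-ds/--dataset-config` overriding the defaults. It times `encode_batch_fast` with and without backtracking over XNLI and asserts that the two encoders produce identical ids.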
35 changes: 24 additions & 11 deletions bindings/python/src/models.rs
@@ -4,11 +4,12 @@ use std::sync::{Arc, RwLock};

use crate::token::PyToken;
use crate::trainers::PyTrainer;
use ahash::AHashMap;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Serialize};
use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
use tk::models::bpe::{BpeBuilder, Merges, BPE};
use tk::models::unigram::Unigram;
use tk::models::wordlevel::WordLevel;
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
@@ -347,9 +348,10 @@ macro_rules! setter {

#[derive(FromPyObject)]
enum PyVocab {
Vocab(Vocab),
Vocab(HashMap<String, u32>),
Filename(String),
}

#[derive(FromPyObject)]
enum PyMerges {
Merges(Merges),
@@ -454,6 +456,7 @@ impl PyBPE {
if let (Some(vocab), Some(merges)) = (vocab, merges) {
match (vocab, merges) {
(PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => {
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
builder = builder.vocab_and_merges(vocab, merges);
}
(PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => {
@@ -494,13 +497,15 @@ impl PyBPE {
/// The vocabulary and merges loaded into memory
#[staticmethod]
#[pyo3(text_signature = "(self, vocab, merges)")]
fn read_file(vocab: &str, merges: &str) -> PyResult<(Vocab, Merges)> {
BPE::read_file(vocab, merges).map_err(|e| {
fn read_file(vocab: &str, merges: &str) -> PyResult<(HashMap<String, u32>, Merges)> {
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while reading vocab & merges files: {}",
e
))
})
})?;
let vocab = vocab.into_iter().collect();
Ok((vocab, merges))
}

/// Instantiate a BPE model from the given files.
@@ -536,6 +541,7 @@ impl PyBPE {
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e))
})?;
let vocab = vocab.into_iter().collect();
Py::new(
py,
PyBPE::new(
@@ -668,6 +674,7 @@ impl PyWordPiece {
if let Some(vocab) = vocab {
match vocab {
PyVocab::Vocab(vocab) => {
let vocab: AHashMap<_, _> = vocab.into_iter().collect();
builder = builder.vocab(vocab);
}
PyVocab::Filename(vocab_filename) => {
@@ -699,10 +706,11 @@ impl PyWordPiece {
/// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
#[staticmethod]
#[pyo3(text_signature = "(vocab)")]
fn read_file(vocab: &str) -> PyResult<Vocab> {
WordPiece::read_file(vocab).map_err(|e| {
fn read_file(vocab: &str) -> PyResult<HashMap<String, u32>> {
let vocab = WordPiece::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e))
})
})?;
Ok(vocab.into_iter().collect())
}

/// Instantiate a WordPiece model from the given file
@@ -734,6 +742,7 @@ impl PyWordPiece {
let vocab = WordPiece::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e))
})?;
let vocab = vocab.into_iter().collect();
Py::new(
py,
PyWordPiece::new(py, Some(PyVocab::Vocab(vocab)), kwargs)?,
@@ -778,6 +787,7 @@ impl PyWordLevel {
if let Some(vocab) = vocab {
match vocab {
PyVocab::Vocab(vocab) => {
let vocab = vocab.into_iter().collect();
builder = builder.vocab(vocab);
}
PyVocab::Filename(vocab_filename) => {
@@ -818,10 +828,12 @@ impl PyWordLevel {
/// :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
#[staticmethod]
#[pyo3(text_signature = "(vocab)")]
fn read_file(vocab: &str) -> PyResult<Vocab> {
WordLevel::read_file(vocab).map_err(|e| {
fn read_file(vocab: &str) -> PyResult<HashMap<String, u32>> {
let vocab = WordLevel::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
})
})?;
let vocab: HashMap<_, _> = vocab.into_iter().collect();
Ok(vocab)
}

/// Instantiate a WordLevel model from the given file
@@ -853,6 +865,7 @@ impl PyWordLevel {
let vocab = WordLevel::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordLevel file: {}", e))
})?;
let vocab = vocab.into_iter().collect();
Py::new(
py,
PyWordLevel::new(py, Some(PyVocab::Vocab(vocab)), unk_token)?,
15 changes: 15 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -1,6 +1,8 @@
use serde::Serialize;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};
use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::ModelWrapper;

use numpy::{npyffi, PyArray1, PyArrayMethods};
use pyo3::class::basic::CompareOp;
@@ -1118,6 +1120,19 @@ impl PyTokenizer {
.into()
})
}
/// Switch the underlying model to the backtracking BPE encoder (currently only implemented for BPE).
#[pyo3(signature = ())]
#[pyo3(text_signature = "(self)")]
fn enable_backtrack(&mut self) -> PyResult<()> {
// self.tokenizer.with_pre_tokenizer(None::<ByteLevel>);
let model = self.tokenizer.get_model();
let mut model = model.model.write().unwrap();
let ModelWrapper::BPE(ref mut model) = *model else {
todo!();
};
model.enable_backtrack();
Ok(())
}

/// Decode the given list of ids back to a string
///
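To make the intent of `enable_backtrack()` concrete, here is a minimal usage sketch from the Python side, mirroring what `benches/test_backtrack.py` does. The model id is only the benchmark's default and is assumed to be accessible; per the Rust above, the call only works for BPE models (anything else currently hits the `todo!()`).

from tokenizers import Tokenizer

# Illustrative sketch, not part of this PR's diff.
# Load the same tokenizer twice and flip one copy to the backtracking encoder.
baseline = Tokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
backtrack = Tokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
backtrack.enable_backtrack()  # new binding; rewrites the wrapped BPE model in place

# The two encoders are expected to agree; the benchmark asserts this over XNLI.
text = " cream"
assert baseline.encode(text).ids == backtrack.encode(text).ids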