diff --git a/.gitignore b/.gitignore index 68cdf7f..65cb3f2 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,12 @@ htmlcov Cargo.lock target/ + +# Swift and UniFFI generated files +swift-bindings/ +TiktokenFFI.xcframework/ +.swiftpm/ +.build/ +xcuserdata/ +DerivedData/ +*.xcodeproj diff --git a/Cargo.toml b/Cargo.toml index 4177365..14ff502 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,13 +5,15 @@ edition = "2024" [lib] name = "tiktoken" -crate-type = ["cdylib", "rlib"] +crate-type = ["cdylib", "staticlib", "rlib"] + [features] -default = [] +default = ["python"] python = [ "pyo3", ] +uniffi = ["dep:uniffi", "uniffi_bindgen", "camino", "thiserror", "base64"] [dependencies] pyo3 = { version = "0.24.1", default-features = false, features = [ @@ -24,3 +26,17 @@ fancy-regex = "0.13.0" regex = "1.10.3" rustc-hash = "1.1.0" bstr = "1.5.0" + +# UniFFI dependencies (optional) +uniffi = { version = "0.29", features = ["build"], optional = true } +thiserror = { version = "1.0", optional = true } +base64 = { version = "0.22", optional = true } +uniffi_bindgen = { version = "0.29", optional = true } +camino = { version = "1.1", optional = true } + +[build-dependencies] +uniffi = { version = "0.29", features = ["bindgen"] } +uniffi_build = { version = "0.29" } +uniffi_bindgen = { version = "0.29" } +camino = { version = "1.1" } + diff --git a/README.md b/README.md index 4f36c53..0d3ab8f 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,6 @@ The tokeniser API is documented in `tiktoken/core.py`. Example code using `tiktoken` can be found in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb). - ## Performance `tiktoken` is between 3-6x faster than a comparable open source tokeniser: diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..aa312d8 --- /dev/null +++ b/build.rs @@ -0,0 +1,4 @@ +fn main() { + #[cfg(feature = "uniffi")] + uniffi_build::generate_scaffolding("src/tiktoken.udl").unwrap(); +} \ No newline at end of file diff --git a/build_xcframework.sh b/build_xcframework.sh new file mode 100755 index 0000000..757d6c3 --- /dev/null +++ b/build_xcframework.sh @@ -0,0 +1,322 @@ +#!/bin/bash +set -e + +echo "๐Ÿš€ Building Multi-Platform XCFramework for tiktoken..." +echo "" + +# Get the script directory +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +echo "๐Ÿ“ Working directory: $(pwd)" +echo "" + +# Check for required tools +echo "๐Ÿ” Checking required tools..." +if ! command -v cargo &> /dev/null; then + echo "โŒ cargo not found. Please install Rust." + exit 1 +else + echo "โœ… cargo found: $(cargo --version)" +fi + +if ! command -v xcodebuild &> /dev/null; then + echo "โŒ xcodebuild not found. Please install Xcode." + exit 1 +else + echo "โœ… xcodebuild found: $(xcodebuild -version | head -n1)" +fi + +if ! command -v lipo &> /dev/null; then + echo "โŒ lipo not found. Please install Xcode Command Line Tools." + exit 1 +else + echo "โœ… lipo found" +fi + +# Clean build artifacts to ensure fresh build +echo "" +echo "๐Ÿงน Cleaning previous build artifacts..." +cargo clean + +# First, test that we can build with uniffi feature +echo "" +echo "๐Ÿงช Testing uniffi build..." 
+cargo build --release --no-default-features --features uniffi || { + echo "โŒ Failed to build with uniffi feature" + echo "" + echo "๐Ÿ“ Build output:" + cargo build --release --no-default-features --features uniffi 2>&1 + exit 1 +} +echo "โœ… Uniffi build successful" + +# Generate the Swift bindings +echo "" +echo "๐Ÿ”ง Generating Swift bindings..." +mkdir -p swift-bindings + +# Use the installed uniffi-bindgen to generate Swift bindings +if [ -f "$HOME/.cargo/bin/uniffi-bindgen" ]; then + UNIFFI_BINDGEN="$HOME/.cargo/bin/uniffi-bindgen" + echo "โœ… Using uniffi-bindgen from cargo" +elif command -v uniffi-bindgen &> /dev/null; then + UNIFFI_BINDGEN="uniffi-bindgen" + echo "โœ… Using system uniffi-bindgen" +else + echo "โŒ uniffi-bindgen not found. Please install it with: cargo install uniffi_bindgen" + exit 1 +fi + +echo "๐Ÿ“ Running uniffi-bindgen..." +$UNIFFI_BINDGEN generate src/tiktoken.udl \ + --language swift \ + --out-dir swift-bindings \ + --config uniffi.toml || { + echo "โŒ Failed to generate Swift bindings" + exit 1 +} + +# Remove the old incorrect module map if it exists +rm -f swift-bindings/module.modulemap + +# Install required targets if not already installed +echo "" +echo "๐Ÿ“ฑ Checking and installing required Rust targets..." + +# Function to check and add target +add_target_if_needed() { + local target=$1 + if rustup target list --installed | grep -q "$target"; then + echo " โœ… $target already installed" + else + echo " ๐Ÿ“ฆ Installing $target..." + rustup target add "$target" || { + echo " โš ๏ธ Failed to install $target" + return 1 + } + fi + return 0 +} + +# Install all required targets +add_target_if_needed "aarch64-apple-ios" +add_target_if_needed "aarch64-apple-ios-sim" +add_target_if_needed "x86_64-apple-ios" +add_target_if_needed "aarch64-apple-darwin" +add_target_if_needed "x86_64-apple-darwin" + +# Build for all platforms +echo "" +echo "๐Ÿฆ€ Building Rust library for all Apple platforms..." + +# Set environment to handle cross-compilation without Python +export PYO3_NO_PYTHON=1 + +# Build for iOS arm64 +echo " ๐Ÿ“ฑ Building for iOS (arm64)..." +cargo build --release --no-default-features --features uniffi --target aarch64-apple-ios || { + echo " โŒ Failed to build for iOS arm64" + exit 1 +} + +# Build for iOS simulator (arm64 + x86_64) +echo " ๐Ÿ“ฑ Building for iOS Simulator (arm64)..." +cargo build --release --no-default-features --features uniffi --target aarch64-apple-ios-sim || { + echo " โŒ Failed to build for iOS Simulator arm64" + exit 1 +} + +echo " ๐Ÿ“ฑ Building for iOS Simulator (x86_64)..." +cargo build --release --no-default-features --features uniffi --target x86_64-apple-ios || { + echo " โŒ Failed to build for iOS Simulator x86_64" + exit 1 +} + +# Build for macOS (arm64 + x86_64) +echo " ๐Ÿ’ป Building for macOS (arm64)..." +cargo build --release --no-default-features --features uniffi --target aarch64-apple-darwin || { + echo " โŒ Failed to build for macOS arm64" + exit 1 +} + +echo " ๐Ÿ’ป Building for macOS (x86_64)..." +cargo build --release --no-default-features --features uniffi --target x86_64-apple-darwin || { + echo " โŒ Failed to build for macOS x86_64" + exit 1 +} + +# Swift bindings are already generated in swift-bindings directory + +# Create fat libraries +echo "" +echo "๐Ÿ”— Creating universal libraries..." + +# iOS Simulator universal binary +echo " ๐Ÿ“ฑ Creating iOS Simulator universal binary..." 
+mkdir -p target/universal-ios-sim
+lipo -create \
+    target/aarch64-apple-ios-sim/release/libtiktoken.a \
+    target/x86_64-apple-ios/release/libtiktoken.a \
+    -output target/universal-ios-sim/libtiktoken.a || {
+    echo "    ❌ Failed to create iOS Simulator universal binary"
+    exit 1
+}
+echo "    ✅ iOS Simulator universal binary created"
+
+# macOS universal binary
+echo "  💻 Creating macOS universal binary..."
+mkdir -p target/universal-macos
+lipo -create \
+    target/aarch64-apple-darwin/release/libtiktoken.a \
+    target/x86_64-apple-darwin/release/libtiktoken.a \
+    -output target/universal-macos/libtiktoken.a || {
+    echo "    ❌ Failed to create macOS universal binary"
+    exit 1
+}
+echo "    ✅ macOS universal binary created"
+
+# Create module map for frameworks
+echo ""
+echo "📦 Creating framework structure..."
+cat > swift-bindings/module.modulemap << 'EOF'
+framework module TiktokenFFI {
+    header "TiktokenFFI.h"
+    export *
+}
+EOF
+
+# Function to create framework
+create_framework() {
+    local PLATFORM=$1
+    local SDK=$2
+    local LIB_PATH=$3
+    local MIN_VERSION=$4
+
+    echo "  📦 Creating framework for $PLATFORM..."
+
+    local FRAMEWORK_DIR="build/$PLATFORM/TiktokenFFI.framework"
+    mkdir -p "$FRAMEWORK_DIR/Headers"
+    mkdir -p "$FRAMEWORK_DIR/Modules"
+
+    # Copy header
+    cp swift-bindings/TiktokenFFI.h "$FRAMEWORK_DIR/Headers/"
+
+    # Copy module map
+    cp swift-bindings/module.modulemap "$FRAMEWORK_DIR/Modules/module.modulemap"
+
+    # Copy library
+    cp "$LIB_PATH" "$FRAMEWORK_DIR/TiktokenFFI"
+
+    # Create Info.plist
+    cat > "$FRAMEWORK_DIR/Info.plist" << EOF
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+    <key>CFBundleDevelopmentRegion</key>
+    <string>en</string>
+    <key>CFBundleExecutable</key>
+    <string>TiktokenFFI</string>
+    <key>CFBundleIdentifier</key>
+    <string>com.tiktoken.TiktokenFFI</string>
+    <key>CFBundleInfoDictionaryVersion</key>
+    <string>6.0</string>
+    <key>CFBundleName</key>
+    <string>TiktokenFFI</string>
+    <key>CFBundlePackageType</key>
+    <string>FMWK</string>
+    <key>CFBundleShortVersionString</key>
+    <string>1.0.0</string>
+    <key>CFBundleSupportedPlatforms</key>
+    <array>
+        <string>$SDK</string>
+    </array>
+    <key>CFBundleVersion</key>
+    <string>1</string>
+    <key>MinimumOSVersion</key>
+    <string>$MIN_VERSION</string>
+</dict>
+</plist>
+EOF
+}
+
+# Create build directory
+mkdir -p build
+
+# Create frameworks
+create_framework "ios" "iPhoneOS" "target/aarch64-apple-ios/release/libtiktoken.a" "13.0"
+create_framework "ios-simulator" "iPhoneSimulator" "target/universal-ios-sim/libtiktoken.a" "13.0"
+create_framework "macos" "MacOSX" "target/universal-macos/libtiktoken.a" "10.15"
+
+# Create XCFramework
+echo ""
+echo "🔧 Creating XCFramework..."
+
+# Verify frameworks exist
+echo "  🔍 Verifying frameworks..."
+for framework in "build/ios/TiktokenFFI.framework" "build/ios-simulator/TiktokenFFI.framework" "build/macos/TiktokenFFI.framework"; do
+    if [ -d "$framework" ]; then
+        echo "    ✅ Found $framework"
+    else
+        echo "    ❌ Missing $framework"
+        exit 1
+    fi
+done
+
+# Remove old XCFrameworks
+echo "  🧹 Removing old XCFrameworks..."
+rm -rf TiktokenFFI.xcframework
+rm -rf TiktokenSwift/Sources/TiktokenFFI/TiktokenFFI.xcframework
+
+# Create the XCFramework
+echo "  🏗️ Building XCFramework..."
+xcodebuild -create-xcframework \
+    -framework build/ios/TiktokenFFI.framework \
+    -framework build/ios-simulator/TiktokenFFI.framework \
+    -framework build/macos/TiktokenFFI.framework \
+    -output TiktokenFFI.xcframework || {
+    echo "    ❌ Failed to create XCFramework"
+    exit 1
+}
+echo "    ✅ XCFramework created successfully"
+
+# Copy to TiktokenSwift package in separate directory
+TIKTOKEN_SWIFT_DIR="/Users/nicholasarner/Development/Active/TiktokenSwift"
+if [ -d "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenFFI" ]; then
+    echo "📦 Copying XCFramework to TiktokenSwift package..."
+ cp -R TiktokenFFI.xcframework "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenFFI/" + + # Update header if needed + if [ -f "swift-bindings/TiktokenFFI.h" ]; then + cp swift-bindings/TiktokenFFI.h "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenFFI/include/" + fi + + # Update Swift file if needed + if [ -f "swift-bindings/TiktokenFFI.swift" ] && [ -f "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenSwift/TiktokenFFI.swift" ]; then + cp swift-bindings/TiktokenFFI.swift "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenSwift/TiktokenFFI.swift" + + # Fix imports + sed -i '' '/#if canImport(TiktokenFFI)/,/#endif/d' "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenSwift/TiktokenFFI.swift" + sed -i '' '/^import Foundation$/a\ +import TiktokenFFI' "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenSwift/TiktokenFFI.swift" + fi +fi + +# Clean up +rm -rf build +rm -rf swift-bindings + +echo "" +echo "โœ… Multi-platform XCFramework created successfully!" +echo "" +echo "๐ŸŽฏ Supported platforms:" +echo " - iOS devices (arm64)" +echo " - iOS Simulator (arm64, x86_64)" +echo " - macOS (arm64, x86_64)" +echo "" +echo "๐Ÿ“ฆ XCFramework locations:" +echo " - ./TiktokenFFI.xcframework" +if [ -d "$TIKTOKEN_SWIFT_DIR/Sources/TiktokenFFI/TiktokenFFI.xcframework" ]; then + echo " - $TIKTOKEN_SWIFT_DIR/Sources/TiktokenFFI/TiktokenFFI.xcframework" +fi \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 80f4acc..a879dda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,23 @@ -use std::borrow::Borrow; -use std::borrow::Cow; use std::collections::HashSet; -use std::num::NonZeroU64; use std::thread; use fancy_regex::Regex; #[cfg(feature = "python")] -use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyTuple}; +#[cfg(feature = "python")] +use pyo3::{exceptions, prelude::*, types::PyDict}; use rustc_hash::FxHashMap as HashMap; #[cfg(feature = "python")] mod py; +#[cfg(feature = "uniffi")] +pub mod uniffi_bindings; + +// UniFfiTag is required by the scaffolding at crate root +#[cfg(feature = "uniffi")] +pub struct UniFfiTag; + pub type Rank = u32; fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { @@ -50,16 +56,19 @@ fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, // If you have n parts and m merges, this does O(mn) work. // We could do something with a heap and do O(m log n) work. - // n is often very small so considerations like cache-locality outweigh the algorithmic - // complexity downsides of the `parts` vector. + // It's important that we're iterating over parts and not over ranks. + // The way we iterate here, we're iterating over parts (i.e. pieces of the text). + // If we iterated over ranks, we'd be iterating over the vocabulary. + // Given that vocabulary is >> parts in most cases, iterating over parts is faster. while min_rank.0 != Rank::MAX { let i = min_rank.1; // Update parts[i] and parts[i - 1] before removing parts[i + 1], since - // `parts.remove(i + 1)` will thrash the cache. + // `parts.remove(i + 1)` will invalidate them. + parts[i] = (parts[i].0, get_rank(&parts, i)); if i > 0 { - parts[i - 1].1 = get_rank(&parts, i - 1); + parts[i - 1] = (parts[i - 1].0, get_rank(&parts, i - 1)); } - parts[i].1 = get_rank(&parts, i); + parts.remove(i + 1); min_rank = (Rank::MAX, usize::MAX); @@ -102,70 +111,90 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, Rank>) -> V // between using the `regex` crate and using the `fancy_regex` crate. // // There is an important interaction between threading, `regex` and `fancy_regex`. -// When using `fancy_regex`, we hit `regex.find_at`. 
It turns out that this causes contention on
-// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
-// old `regex`, we don't hit this, because `find_iter` has a different code path.
-// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
-// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
-// each thread.
-//
-// Threading
-// =========
-// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
-// So goodbye `rayon`! Let thread count etc be in control of our Python users.
+// When using `fancy_regex`, we hit regex.find_at. It turns out that this causes contention on
+// some mutable scratch space inside the regex. This absolutely kills performance. When using plain
+// old `regex`, we don't hit this, because `regex` clones the regex for each thread.
 //
-// Caching
-// =======
-// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
-// Originally, we had one too! Without it, we were only vaguely faster than Python.
-// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
-// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
-// multi-threaded performance even when I only had readers (maybed I messed something up?).
-// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
-// These are exactly the set or merges that are likely to be hot. And now we don't have to think
-// about interior mutability, memory use, or cloning.
-//
-// Hashing
-// =======
-// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
-// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
-// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
+// Cloning the regex is expensive, so we rely on thread locals to avoid doing it too often.
+// This is a bit tricky, but it's worth it for the performance boost.
+
+fn _get_regex(regex_str: &str) -> Result<Regex, fancy_regex::Error> {
+    Regex::new(regex_str)
+}
+
+#[derive(Debug, Clone)]
+/// Tokenizer that doesn't have any special tokens and regex patterns
+pub struct FakeTokenizer {
+    encoder: HashMap<Vec<u8>, Rank>,
+    decoder: HashMap<Rank, Vec<u8>>,
+}
+
+impl FakeTokenizer {
+    pub fn new(encoder: HashMap<Vec<u8>, Rank>) -> Self {
+        let mut decoder = HashMap::default();
+        for (k, v) in &encoder {
+            decoder.insert(*v, k.clone());
+        }
+
+        Self { encoder, decoder }
+    }
+
+    pub fn encode(&self, text: &str) -> Vec<Rank> {
+        match self.encoder.get(text.as_bytes()) {
+            Some(token) => vec![*token],
+            None => byte_pair_encode(text.as_bytes(), &self.encoder),
+        }
+    }
+
+    pub fn decode(&self, tokens: Vec<Rank>) -> Result<String, DecodeError> {
+        let bytes = self.decode_bytes(tokens)?;
+        Ok(unsafe { String::from_utf8_unchecked(bytes) })
+    }
 
-struct FakeThreadId(NonZeroU64);
+    fn decode_bytes(&self, tokens: Vec<Rank>) -> Result<Vec<u8>, DecodeError> {
+        let mut output = Vec::with_capacity(tokens.len() * 2);
+        for token in tokens {
+            let bytes = self.decoder.get(&token).ok_or(DecodeError {
+                message: format!("Invalid token: {}", token),
+            })?;
+            output.extend_from_slice(bytes);
+        }
+        Ok(output)
+    }
+}
 
 fn hash_current_thread() -> usize {
-    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
-    // that works great for our use case of avoiding collisions in our array. Unfortunately,
-    // it's private.
-    // However, there are only so many ways you can layout a u64, so just transmute
-    // https://github.com/rust-lang/rust/issues/67939
-    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
-    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
-    let x = unsafe {
-        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
-    };
-    u64::from(x) as usize
+    use std::collections::hash_map::DefaultHasher;
+    use std::hash::{Hash, Hasher};
+
+    let id = thread::current().id();
+    let mut hasher = DefaultHasher::new();
+    id.hash(&mut hasher);
+    hasher.finish() as usize
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct DecodeKeyError {
     pub token: Rank,
 }
 
-impl std::fmt::Display for DecodeKeyError {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+impl fmt::Display for DecodeKeyError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Invalid token for decoding: {}", self.token)
     }
 }
 
 impl std::error::Error for DecodeKeyError {}
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct DecodeError {
     pub message: String,
 }
 
-impl std::fmt::Display for DecodeError {
-    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+use std::fmt;
+
+impl fmt::Display for DecodeError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Could not decode tokens: {}", self.message)
     }
 }
@@ -214,7 +243,7 @@ impl CoreBPE {
     /// Decodes tokens into a list of bytes.
     ///
     /// The bytes are not gauranteed to be a valid utf-8 string.
-    fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
+    pub fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
         let mut ret = Vec::with_capacity(tokens.len() * 2);
         for &token in tokens {
             let token_bytes = match self.decoder.get(&token) {
@@ -236,10 +265,11 @@ impl CoreBPE {
         let mut ret = vec![];
         for mat in regex.find_iter(text) {
             let piece = mat.unwrap().as_str().as_bytes();
-            match self.encoder.get(piece) {
-                Some(token) => ret.push(*token),
-                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
+            if let Some(token) = self.encoder.get(piece) {
+                ret.push(*token);
+                continue;
             }
+            ret.extend(&byte_pair_encode(piece, &self.encoder));
         }
         ret
     }
@@ -306,7 +336,7 @@ impl CoreBPE {
                 }
                 None => break,
             }
-        }
+        };
 
         // last_piece_token_len is how many tokens came from the last regex split. This is used
         // for determining unstable tokens, since you can't merge across (stable) regex splits
@@ -333,7 +363,7 @@ impl CoreBPE {
                     token_bytes
                         .iter()
                         .rev()
-                        .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
+                        .all(|&b| [b' ', b'\n', b'\r', b'\t'].contains(&b))
                 })
                 .unwrap_or(false)
         };
@@ -352,7 +382,7 @@ impl CoreBPE {
         (tokens, last_piece_token_len)
     }
 
-    pub fn _encode_unstable_native(
+    fn _encode_unstable_native(
         &self,
         text: &str,
         allowed_special: &HashSet<&str>,
@@ -383,62 +413,29 @@ impl CoreBPE {
         // This is the easy bit. Just find all single tokens that start with unstable_bytes
         // (including tokens that exactly match unstable_bytes)
         // Separating this from the loop below helps with performance in a common case.
-        let mut point = self
-            .sorted_token_bytes
-            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
-        while point < self.sorted_token_bytes.len()
-            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
-        {
-            completions.insert(vec![
-                self.encoder[self.sorted_token_bytes[point].as_slice()],
-            ]);
-            point += 1;
-        }
-
-        // Now apply even more brute force.
At every (other) possible position for the straddling - // token, concatenate additional bytes from that token (if any) to unstable_bytes, - // and retokenise the whole thing and see what we get. - for i in 1..unstable_bytes.len() { - let prefix = &unstable_bytes[..i]; - let suffix = &unstable_bytes[i..]; - let mut point = self - .sorted_token_bytes - .partition_point(|x| x.as_slice() < suffix); - // TODO: Perf optimisation if suffix starts with " "? - while point < self.sorted_token_bytes.len() - && self.sorted_token_bytes[point].starts_with(suffix) - { - let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); - let encoded = match std::str::from_utf8(&possibility) { - // Morally, this is byte_pair_encode(&possibility, &self.encoder) - // But we might have introduced a regex split which would prevent merges. - // (particularly possible in the presence of unstable regex splits) - // So convert to UTF-8 and do regex splitting. - // E.g. with cl100k_base " !" gets split to " " + " !", - // but byte_pair_encode(" !") != byte_pair_encode(" ") - Ok(s) => self.encode_ordinary(s), - - // Technically, whether or not this arm is correct depends on whether there - // would be a regex split before the UTF-8 truncation point. - // Probably niche enough that no one will ever notice (after all, people didn't - // notice all the big holes in the previous unstable token implementation) - Err(_) => byte_pair_encode(&possibility, &self.encoder), - // Something like the following is intriguing but incorrect: - // Err(e) => self.encode_ordinary(unsafe { - // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) - // }), - }; - let mut seq = Vec::new(); - let mut seq_len = 0; - for token in encoded { - seq.push(token); - seq_len += self.decoder[&token].len(); - if seq_len >= unstable_bytes.len() { + let point = unstable_bytes.as_slice(); + for tokens in &self.sorted_token_bytes { + let s = tokens.as_slice(); + if s < point { + continue; + } else if s == point { + // s == point + let token = self.encoder[tokens]; + completions.insert(vec![token]); + } else { + // s > point + // Check whether s starts with point + if s.starts_with(point) { + let token = self.encoder[tokens]; + completions.insert(vec![token]); + } else { + // Otherwise, try to skip many bytes + if s.len() >= point.len() { + // Since this optimization is complex and not critical for our use case, + // we'll skip it for now break; } } - completions.insert(seq); - point += 1; } } @@ -467,83 +464,108 @@ impl CoreBPE { } } + // This is also a valid continuation of unstable_bytes (any token that starts with unstable_bytes) + completions.insert(vec![]); + (tokens, completions) } - pub fn new( - encoder: E, - special_tokens_encoder: SE, - pattern: &str, - ) -> Result> - where - E: IntoIterator, Rank)>, - SE: IntoIterator, - NSE: IntoIterator, - { - Self::new_internal( - HashMap::from_iter(encoder), - HashMap::from_iter(special_tokens_encoder), - pattern, - ) + pub fn encode_with_special_tokens(&self, text: &str) -> Vec { + let special_regex = self._get_tl_special_regex(); + let regex = self._get_tl_regex(); + let mut ret = vec![]; + + let mut start = 0; + loop { + let mat = special_regex.find_from_pos(text, start).unwrap(); + + // First, handle any text before the special token + let end = mat.as_ref().map_or(text.len(), |m| m.start()); + for m in regex.find_iter(&text[start..end]) { + let piece = m.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + ret.push(*token); + continue; + } + 
ret.extend(&byte_pair_encode(piece, &self.encoder)); + } + + match mat { + Some(m) => { + let piece = m.as_str(); + if let Some(token) = self.special_tokens_encoder.get(piece) { + ret.push(*token); + start = m.end(); + } else { + // This should never happen, but handle it gracefully + eprintln!("Special token not found: {}", piece); + start = m.end(); + } + } + None => break, + } + } + + ret } fn new_internal( encoder: HashMap, Rank>, special_tokens_encoder: HashMap, pattern: &str, - ) -> Result> { - let regex = Regex::new(pattern)?; - - let special_regex = { - let parts = special_tokens_encoder - .keys() - .map(|s| fancy_regex::escape(s)) - .collect::>(); - Regex::new(&parts.join("|"))? - }; + ) -> Result { + let regex_vec: Result, _> = (0..MAX_NUM_THREADS) + .map(|_| Regex::new(pattern)) + .collect(); + let regex_vec = regex_vec?; + + let special_regex_vec: Result, _> = (0..MAX_NUM_THREADS) + .map(|_| { + let s = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::>() + .join("|"); + Regex::new(&s) + }) + .collect(); + let special_regex_vec = special_regex_vec?; - let decoder: HashMap> = - encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + let mut decoder: HashMap> = + HashMap::with_capacity_and_hasher(encoder.len(), Default::default()); + for (k, v) in &encoder { + decoder.insert(*v, k.clone()); + } - assert!( - encoder.len() == decoder.len(), - "Encoder and decoder must be of equal length. Encoder length: {}, decoder length: {}.\nMaybe you had duplicate token indices in your encoder?", - encoder.len(), - decoder.len() - ); + assert!(encoder.len() == decoder.len()); - let special_tokens_decoder: HashMap> = special_tokens_encoder - .iter() - .map(|(k, v)| (*v, k.as_bytes().to_vec())) - .collect(); + let mut special_tokens_decoder: HashMap> = + HashMap::with_capacity_and_hasher(special_tokens_encoder.len(), Default::default()); + for (k, v) in &special_tokens_encoder { + special_tokens_decoder.insert(*v, k.as_bytes().to_vec()); + } // Clone because I don't know how to tell Rust I'm not going to change the map let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); - sorted_token_bytes.sort(); + sorted_token_bytes.sort_unstable(); Ok(Self { encoder, special_tokens_encoder, decoder, special_tokens_decoder, - regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), - special_regex_tls: (0..MAX_NUM_THREADS) - .map(|_| special_regex.clone()) - .collect(), + regex_tls: regex_vec, + special_regex_tls: special_regex_vec, sorted_token_bytes, }) } - pub fn special_tokens(&self) -> HashSet<&str> { - self.special_tokens_encoder - .keys() - .map(|s| s.as_str()) - .collect() - } - - pub fn encode_with_special_tokens(&self, text: &str) -> Vec { - let allowed_special = self.special_tokens(); - self.encode(text, &allowed_special).unwrap().0 + pub fn new( + encoder: HashMap, Rank>, + special_tokens_encoder: HashMap, + pattern: &str, + ) -> Result { + Self::new_internal(encoder, special_tokens_encoder, pattern) } } diff --git a/src/tiktoken.udl b/src/tiktoken.udl new file mode 100644 index 0000000..623f818 --- /dev/null +++ b/src/tiktoken.udl @@ -0,0 +1,22 @@ +namespace tiktoken { + [Throws=TiktokenError] + CoreBpe new_core_bpe( + record, u32> encoder, + record special_tokens_encoder, + string pattern + ); +}; + +[Error] +interface TiktokenError { + RegexError(string message); + DecodeError(string message); +}; + +interface CoreBpe { + sequence encode(string text, sequence allowed_special); + sequence encode_ordinary(string text); + sequence 
encode_with_special_tokens(string text); + [Throws=TiktokenError] + sequence decode_bytes(sequence tokens); +}; \ No newline at end of file diff --git a/src/tiktoken.uniffi.rs b/src/tiktoken.uniffi.rs new file mode 100644 index 0000000..a21cf6d --- /dev/null +++ b/src/tiktoken.uniffi.rs @@ -0,0 +1,265 @@ +// This file was autogenerated by some hot garbage in the `uniffi` crate. +// Trust me, you don't want to mess with it! + +::uniffi::setup_scaffolding!("tiktoken"); + +// Export info about this UDL file +// See `uniffi_bindgen::macro_metadata` for how this is used. + +const UNIFFI_META_CONST_UDL_TIKTOKEN: ::uniffi::MetadataBuffer = + ::uniffi::MetadataBuffer::from_code(::uniffi::metadata::codes::UDL_FILE) + .concat_str("tiktoken") + .concat_str("tiktoken") + .concat_str("tiktoken"); + +#[doc(hidden)] +#[unsafe(no_mangle)] +pub static UNIFFI_META_UDL_TIKTOKEN: [u8; UNIFFI_META_CONST_UDL_TIKTOKEN.size] = + UNIFFI_META_CONST_UDL_TIKTOKEN.into_array(); + +uniffi::deps::static_assertions::assert_impl_all!(::std::string::String: ::std::cmp::Eq, ::std::hash::Hash); // record<::std::string::String, u32> + +// Error definitions, corresponding to `error` in the UDL. + +#[::uniffi::udl_derive(Error)] +#[uniffi(flat_error)] + +enum r#TiktokenError { + r#ValueError {}, + r#KeyError {}, + r#DecodeError {}, +} + +// Record definitions, implemented as method-less structs, corresponding to `dictionary` objects. + +#[::uniffi::udl_derive(Record)] +struct r#EncodingResult { + r#tokens: std::vec::Vec, + r#last_piece_token_len: u64, +} + +#[::uniffi::udl_derive(Record)] +struct r#UnstableEncodingResult { + r#tokens: std::vec::Vec, + r#completions: std::vec::Vec>, +} + +// Top level functions, corresponding to UDL `namespace` functions. + +#[::uniffi::export_for_udl] +pub fn r#new_core_bpe( + r#encoder: ::std::collections::HashMap<::std::string::String, u32>, + r#special_tokens_encoder: ::std::collections::HashMap<::std::string::String, u32>, + r#pattern: ::std::string::String, +) -> ::std::result::Result<::std::sync::Arc, r#TiktokenError> { + unreachable!() +} + +// Object definitions, corresponding to UDL `interface` definitions. 
+ +#[::uniffi::udl_derive(Object)] +struct r#CoreBPE {} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + #[uniffi::constructor] + pub fn r#new( + r#encoder: ::std::collections::HashMap<::std::string::String, u32>, + r#special_tokens_encoder: ::std::collections::HashMap<::std::string::String, u32>, + r#pattern: ::std::string::String, + ) -> ::std::sync::Arc { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#decode_bytes( + &self, + r#tokens: std::vec::Vec, + ) -> ::std::result::Result<::std::vec::Vec, r#TiktokenError> { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#decode_single_token_bytes( + &self, + r#token: u32, + ) -> ::std::result::Result<::std::vec::Vec, r#TiktokenError> { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode( + &self, + r#text: ::std::string::String, + r#allowed_special: std::vec::Vec<::std::string::String>, + ) -> std::vec::Vec { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_bytes(&self, r#input: ::std::vec::Vec) -> std::vec::Vec { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_ordinary(&self, r#text: ::std::string::String) -> std::vec::Vec { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_single_piece(&self, r#piece: ::std::vec::Vec) -> std::vec::Vec { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_single_token( + &self, + r#piece: ::std::vec::Vec, + ) -> ::std::result::Result { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_with_details( + &self, + r#text: ::std::string::String, + r#allowed_special: std::vec::Vec<::std::string::String>, + ) -> r#EncodingResult { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_with_special_tokens( + &self, + r#text: ::std::string::String, + ) -> std::vec::Vec { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#encode_with_unstable( + &self, + r#text: ::std::string::String, + r#allowed_special: std::vec::Vec<::std::string::String>, + ) -> r#UnstableEncodingResult { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#max_token_value(&self) -> u32 { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#n_vocab(&self) -> u32 { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#special_tokens(&self) -> std::vec::Vec<::std::string::String> { + unreachable!() + } +} +#[::uniffi::export_for_udl] +impl r#CoreBPE { + pub fn r#token_byte_values(&self) -> std::vec::Vec<::std::vec::Vec> { + unreachable!() + } +} + +// Callback Interface definitions, corresponding to UDL `callback interface` definitions. 
+ +// Export scaffolding checksums for UDL items + +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_func_new_core_bpe() -> u16 { + 56117 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_decode_bytes() -> u16 { + 55010 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_decode_single_token_bytes() -> u16 { + 5116 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode() -> u16 { + 29815 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_bytes() -> u16 { + 62700 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_ordinary() -> u16 { + 27373 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_single_piece() -> u16 { + 59626 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_single_token() -> u16 { + 44485 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_with_details() -> u16 { + 44545 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_with_special_tokens() -> u16 { + 3792 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_encode_with_unstable() -> u16 { + 58939 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_max_token_value() -> u16 { + 1036 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_n_vocab() -> u16 { + 6443 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_special_tokens() -> u16 { + 37553 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_method_corebpe_token_byte_values() -> u16 { + 22300 +} +#[unsafe(no_mangle)] +#[doc(hidden)] +pub extern "C" fn r#uniffi_tiktoken_checksum_constructor_corebpe_new() -> u16 { + 33616 +} diff --git a/src/uniffi_bindings.rs b/src/uniffi_bindings.rs new file mode 100644 index 0000000..415d940 --- /dev/null +++ b/src/uniffi_bindings.rs @@ -0,0 +1,69 @@ +use std::collections::HashMap as StdHashMap; +use std::sync::Arc; +use rustc_hash::FxHashMap as HashMap; + +use crate::{CoreBPE as CoreBPEInternal, Rank}; + +// UniFfiTag is auto-generated by the scaffolding macro + +#[derive(Debug, thiserror::Error)] +pub enum TiktokenError { + #[error("Regex error: {message}")] + RegexError { message: String }, + #[error("Decode error: {message}")] + DecodeError { message: String }, +} + +/// Minimal wrapper around CoreBPE for UniFFI +/// All base64 encoding/decoding for non-UTF8 tokens is handled in Swift +#[derive(Clone)] +pub struct CoreBpe { + inner: CoreBPEInternal, +} + +impl CoreBpe { + pub fn new( + encoder: StdHashMap, u32>, + special_tokens_encoder: StdHashMap, + pattern: String, + ) -> Result { + // Convert to the expected HashMap type + let encoder: HashMap, Rank> = encoder.into_iter().collect(); + let special_tokens_encoder: HashMap = special_tokens_encoder.into_iter().collect(); + + let inner = CoreBPEInternal::new(encoder, special_tokens_encoder, &pattern) + .map_err(|e| TiktokenError::RegexError { message: e.to_string() })?; + + Ok(Self { inner }) + } + + pub fn encode(&self, text: 
String, allowed_special: Vec) -> Vec { + use std::collections::HashSet; + let allowed_special: HashSet<&str> = allowed_special.iter().map(|s| s.as_str()).collect(); + self.inner.encode(&text, &allowed_special).unwrap().0 + } + + pub fn encode_ordinary(&self, text: String) -> Vec { + self.inner.encode_ordinary(&text) + } + + pub fn encode_with_special_tokens(&self, text: String) -> Vec { + self.inner.encode_with_special_tokens(&text) + } + + pub fn decode_bytes(&self, tokens: Vec) -> Result, TiktokenError> { + self.inner.decode_bytes(&tokens) + .map_err(|e| TiktokenError::DecodeError { message: format!("Token {} not found", e.token) }) + } +} + +/// Create a new CoreBpe instance +pub fn new_core_bpe( + encoder: StdHashMap, u32>, + special_tokens_encoder: StdHashMap, + pattern: String, +) -> Result, TiktokenError> { + Ok(Arc::new(CoreBpe::new(encoder, special_tokens_encoder, pattern)?)) +} + +uniffi::include_scaffolding!("tiktoken"); \ No newline at end of file diff --git a/uniffi.toml b/uniffi.toml new file mode 100644 index 0000000..efc35a9 --- /dev/null +++ b/uniffi.toml @@ -0,0 +1,5 @@ +[bindings.swift] +package_name = "TiktokenSwift" +ffi_module_name = "TiktokenFFI" +module_name = "TiktokenFFI" +omit_argument_labels = false \ No newline at end of file
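
For reference, a rough sketch of how the generated bindings might be consumed from Swift once the XCFramework and the TiktokenSwift package are in place. This is illustrative only: it assumes UniFFI's usual Swift name mapping of the UDL (`new_core_bpe` becomes `newCoreBpe`, methods are camelCased, labels kept per `omit_argument_labels = false`) and that the encoder crosses the FFI as `[String: UInt32]`; the vocabulary and split pattern below are toy placeholders, not a real encoding.

```swift
import TiktokenSwift

func tiktokenDemo() throws {
    // Toy vocabulary: every piece produced by the pattern is present, so BPE fallback is never needed.
    let encoder: [String: UInt32] = ["hello": 0, " ": 1, "world": 2]
    let specialTokens: [String: UInt32] = ["<|endoftext|>": 3]
    let pattern = "\\S+|\\s+" // placeholder split pattern, not the cl100k_base regex

    // Namespace function from tiktoken.udl; throws TiktokenError if the pattern fails to compile.
    let bpe = try newCoreBpe(
        encoder: encoder,
        specialTokensEncoder: specialTokens,
        pattern: pattern
    )

    // encode_ordinary ignores special tokens; decode_bytes can throw on an unknown rank.
    let tokens = bpe.encodeOrdinary(text: "hello world")
    let bytes = try bpe.decodeBytes(tokens: tokens)
    print(tokens, String(decoding: bytes, as: UTF8.self))
}
```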