
Commit 20e948b (1 parent: 47bd6f0)

fix and pass test_with_corpus.py

apply fix to internal_fuzzy_update_codebook

5 files changed: +120 -51 lines

compression.pyi

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ class Codebook:
         The codebook as a list of lists.
         """
         ...
-    def to_decoding_dict(self) -> Dict[int, List[int]]:
+    def to_dict(self) -> Dict[int, List[int]]:
         """
         Get the decoding dictionary for the codebook.
         """

example/run_lzw.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 print(f"input ids: {ids}")
 print(f"compressed_ids: {compressed_ids}")
 print(f"attention_mask: {attention_mask}")
-print(f"codebook: {codebook.to_decoding_dict()}")
+print(f"codebook: {codebook.to_dict()}")


 # decompress

src/lib.rs

Lines changed: 72 additions & 46 deletions
@@ -1,5 +1,5 @@
 use bumpalo::{collections::Vec as BumpVec, Bump};
-use hashbrown::{HashMap, HashSet};
+use std::collections::{HashMap, HashSet, BTreeMap};
 use itertools::Itertools;
 use pyo3::prelude::*;
 use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
@@ -33,14 +33,14 @@ impl CodebookConfig {
         max_codebook_size: usize,
         max_subtokens: usize,
         pad_token_id: usize,
-        disabled_ids: HashSet<usize>,
+        disabled_ids: Option<HashSet<usize>>,
     ) -> Self {
         Self {
             initial_vocab_size,
             max_codebook_size,
             max_subtokens,
             pad_token_id,
-            disabled_ids,
+            disabled_ids: disabled_ids.unwrap_or_default(),
         }
     }
 }
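Making disabled_ids an Option means Rust callers can now pass None instead of building an empty set, with unwrap_or_default() supplying the empty default. A minimal std-only sketch of the pattern (the function name is illustrative, not the crate's API):

    use std::collections::HashSet;

    // An optional disabled-id set collapses to an empty set by default.
    fn normalize(disabled_ids: Option<HashSet<usize>>) -> HashSet<usize> {
        disabled_ids.unwrap_or_default()
    }

    fn main() {
        assert!(normalize(None).is_empty());
        assert!(normalize(Some(HashSet::from([0, 1]))).contains(&1));
    }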
@@ -196,8 +196,8 @@ impl Codebook {
         result
     }

-    pub fn to_decoding_dict(&self) -> HashMap<usize, Vec<usize>> {
-        let mut result = HashMap::with_capacity(self.base_ids2hyper_id_map.len());
+    pub fn to_dict(&self) -> BTreeMap<usize, Vec<usize>> {
+        let mut result = BTreeMap::new();
         for (ids, id) in self.base_ids2hyper_id_map.iter() {
             result.insert(*id, ids.clone());
         }
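The switch from HashMap to BTreeMap makes to_dict deterministic: BTreeMap iterates in ascending key order, so the exported decoding dictionary is stable across runs, while HashMap iteration order varies per process. A small std-only illustration:

    use std::collections::BTreeMap;

    fn main() {
        let mut decoding: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
        decoding.insert(50259, vec![7, 8]);
        decoding.insert(50257, vec![1, 2]);
        decoding.insert(50258, vec![3]);

        // Insertion order does not matter; keys always come out sorted.
        let keys: Vec<usize> = decoding.keys().copied().collect();
        assert_eq!(keys, vec![50257, 50258, 50259]);
    }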
@@ -291,6 +291,8 @@ impl LZWCompressor {

             let id = ids[i];
             i += 1;
+            log::debug!("coming id: {}", id);
+            log::debug!("buffer_ids_to_merge: {:?}", buffer_ids_to_merge);

             if self.config.disabled_ids.contains(&id) {
                 if buffer_ids_to_merge.len() > 0 {
@@ -478,7 +480,7 @@ impl LZWCompressor {
             if self.config.disabled_ids.contains(&id) {
                 previous_ids.clear();
                 output_ids.push(id);
-                log::debug!("emitting id: {}", id);
+                log::debug!("emitting disabled id: {}", id);
                 continue;
             }

@@ -495,7 +497,7 @@ impl LZWCompressor {
             // 1. the id is a new hyper id because of force merge
             } else if previous_ids.len() == self.config.max_subtokens {
                 decoded_ids = previous_ids.clone();
-            // 2. the id is a new hyper id because of cSc pattern merge
+            // 2. the id is a new hyper id because of cScSc pattern merge
             } else {
                 decoded_ids = previous_ids.clone();
                 decoded_ids.push(previous_ids[0]);
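The corrected comment names the classic LZW decoder corner case: when the decoder receives a code it has not defined yet, that code can only be the entry the encoder created one step earlier from a cScSc-shaped input, so the decoded sequence is the previous string plus its own first symbol. A generic sketch of the rule (not the crate's decoder; it assumes previous is non-empty, which holds after the first code):

    use std::collections::HashMap;

    // Resolve an LZW code, handling the not-yet-defined (cScSc) case.
    fn resolve(
        code: usize,
        codebook: &HashMap<usize, Vec<usize>>,
        previous: &[usize],
    ) -> Vec<usize> {
        match codebook.get(&code) {
            Some(entry) => entry.clone(),
            None => {
                // The referenced entry is still under construction on the
                // encoder side: it must be previous + previous's first symbol.
                let mut entry = previous.to_vec();
                entry.push(previous[0]);
                entry
            }
        }
    }

    fn main() {
        let codebook = HashMap::new(); // empty: every code is "unknown"
        assert_eq!(resolve(27, &codebook, &[3, 4]), vec![3, 4, 3]);
    }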
@@ -507,34 +509,53 @@ impl LZWCompressor {
                 previous_ids = decoded_ids.clone();
                 output_ids.extend_from_slice(&decoded_ids);
                 log::debug!("emitting id: {:?}", decoded_ids);
+                continue;
             }
             // we have decoded the id, we can add it to the output
             output_ids.extend_from_slice(&decoded_ids);
+            log::debug!("emitting id: {:?}", decoded_ids);

             // the remaining part is to update the codebook if needed
+
+            if next_id == self.config.initial_vocab_size + self.config.max_codebook_size {
+                log::debug!("max codebook size reached, clearing buffer_ids_to_merge");
+                previous_ids.clear();
+                continue;
+            }
+
+            // starting case
             if previous_ids.len() == 0 {
                 previous_ids = decoded_ids;
                 continue;
             }

-            while next_id < self.config.initial_vocab_size + self.config.max_codebook_size
-                && previous_ids.len() < self.config.max_subtokens && decoded_ids.len() > 0
-            {
-                previous_ids.push(decoded_ids[0]);
+            // the buffer is max size and the buffer is the previous ID
+            // so it must not be a new hyper id but an existing one
+            // we just clear the buffer and continue
+            if previous_ids.len() == self.config.max_subtokens {
+                assert!(existing_codes.contains(&previous_ids), "previous_ids: {:?} not in existing_codes: {:?}", previous_ids, existing_codes);
+                previous_ids = decoded_ids.clone();
+                continue;
+            } else {

-                if !existing_codes.contains(&previous_ids) {
-                    codebook.insert(next_id, previous_ids.clone());
-                    log::debug!("inserting: {:?} -> {:?}", previous_ids, next_id);
-                    next_id += 1;
-                    existing_codes.insert(previous_ids.clone());
-                    previous_ids = decoded_ids.clone();
-                    break;
-                } else if previous_ids.len() == self.config.max_subtokens {
-                    previous_ids = decoded_ids.clone();
-                    break;
-                }
+                while decoded_ids.len() > 0
+                {
+                    previous_ids.push(decoded_ids[0]);
+
+                    if !existing_codes.contains(&previous_ids) {
+                        codebook.insert(next_id, previous_ids.clone());
+                        log::debug!("inserting: {:?} -> {:?}", previous_ids, next_id);
+                        next_id += 1;
+                        existing_codes.insert(previous_ids.clone());
+                        previous_ids = decoded_ids.clone();
+                        break;
+                    } else if previous_ids.len() == self.config.max_subtokens {
+                        previous_ids = decoded_ids.clone();
+                        break;
+                    }

-                decoded_ids.remove(0);
+                    decoded_ids.remove(0);
+                }
             }
         }
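The substantive fix in this hunk: the codebook-size cap is now checked up front, clearing the buffer and skipping the update once next_id reaches initial_vocab_size + max_codebook_size, instead of folding the cap into the old while condition, where it was entangled with the full-buffer and empty-input exits. A condensed sketch of the guard (illustrative names, not the crate's API):

    // Stop learning new codebook entries once the id space is exhausted.
    fn can_grow(next_id: usize, initial_vocab_size: usize, max_codebook_size: usize) -> bool {
        next_id < initial_vocab_size + max_codebook_size
    }

    fn main() {
        // With a 27-symbol alphabet and room for 100 learned codes (the
        // configuration in tests/test.rs), ids 27..=126 are assignable
        // and 127 is out of range.
        assert!(can_grow(126, 27, 100));
        assert!(!can_grow(127, 27, 100));
    }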

@@ -611,7 +632,7 @@ impl LZWCompressor {
                 max_codebook_size,
                 max_subtokens,
                 pad_token_id,
-                disabled_ids,
+                Some(disabled_ids)
             ),
         }
     }
@@ -875,8 +896,8 @@ impl CodebookManager {
             let mut current_ids: Vec<usize>;
             if maybe_hid < config.initial_vocab_size {
                 current_ids = vec![maybe_hid];
-            } else if let Some(entry) = codebook.get_base_ids(maybe_hid) {
-                current_ids = entry.clone();
+            } else if let Some(base_ids) = codebook.get_base_ids(maybe_hid) {
+                current_ids = base_ids.clone();
             // the following are cases when the maybe_hid is an unknown hyper-id
             // (1) the buffer was full and it was inserted in the codebook
             // TODO, I think the following is not needed, because
@@ -898,33 +919,38 @@ impl CodebookManager {

             }

+            if state.next_id == config.initial_vocab_size + config.max_codebook_size {
+                log::debug!("max_codebook_size reached, clearing buffer_ids_to_merge");
+                state.buffer_ids_to_merge.clear();
+                continue;
+            }
+
             // Starting time
             if state.buffer_ids_to_merge.len() == 0 {
                 state.buffer_ids_to_merge = current_ids.clone();
                 continue;
             }

-            while state.next_id < config.initial_vocab_size + config.max_codebook_size
-                && state.buffer_ids_to_merge.len() < config.max_subtokens && current_ids.len() > 0
-            {
-                state.buffer_ids_to_merge.push(current_ids[0]);
-
-                // check if it's already in the codebook
-                if !codebook.contains_key(&state.buffer_ids_to_merge) {
-                    codebook.insert(state.buffer_ids_to_merge.clone(), state.next_id);
-                    log::debug!("inserting: {:?} -> {:?}", state.buffer_ids_to_merge, state.next_id);
-                    state.next_id += 1;
-                    state.buffer_ids_to_merge = current_ids.clone();
-                    break;
-                } // reaching max_subtokens, we need to insert the current_ids into the codebook
-                else if state.buffer_ids_to_merge.len() == config.max_subtokens {
-                    state.buffer_ids_to_merge = current_ids.clone();
-                    break;
-                }
-                current_ids.remove(0);
+            if state.buffer_ids_to_merge.len() == config.max_subtokens {
+                state.buffer_ids_to_merge = current_ids.clone();
+                continue;
+            } else {
+                while current_ids.len() > 0 {
+                    state.buffer_ids_to_merge.push(current_ids[0]);

+                    if !codebook.contains_key(&state.buffer_ids_to_merge) {
+                        codebook.insert(state.buffer_ids_to_merge.clone(), state.next_id);
+                        log::debug!("inserting: {:?} -> {:?}", state.buffer_ids_to_merge, state.next_id);
+                        state.next_id += 1;
+                        state.buffer_ids_to_merge = current_ids.clone();
+                        break;
+                    } else if state.buffer_ids_to_merge.len() == config.max_subtokens {
+                        state.buffer_ids_to_merge = current_ids.clone();
+                        break;
+                    }
+                    current_ids.remove(0);
+                }
             }
-
         }
     }
 }
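The commit applies the same restructuring to the fuzzy update path (per the commit message, internal_fuzzy_update_codebook): a full merge buffer is now handled by an early branch that resets it, rather than by a clause buried in the while condition. A self-contained sketch of the restructured update step, using std types only (this is not the crate's signature, and the size-cap guard shown above would run before it):

    use std::collections::HashMap;

    // Merge-buffer update: grow the buffer until it forms an unseen
    // sequence, record that sequence, then restart from the incoming ids.
    fn update_buffer(
        buffer: &mut Vec<usize>,
        incoming: &[usize],
        codebook: &mut HashMap<Vec<usize>, usize>,
        next_id: &mut usize,
        max_subtokens: usize,
    ) {
        if buffer.is_empty() || buffer.len() == max_subtokens {
            // Starting case, or full buffer (already a known code): reset.
            *buffer = incoming.to_vec();
            return;
        }
        let mut rest = incoming.to_vec();
        while !rest.is_empty() {
            buffer.push(rest[0]);
            if !codebook.contains_key(buffer.as_slice()) {
                codebook.insert(buffer.clone(), *next_id);
                *next_id += 1;
                *buffer = incoming.to_vec();
                return;
            } else if buffer.len() == max_subtokens {
                *buffer = incoming.to_vec();
                return;
            }
            rest.remove(0);
        }
    }

    fn main() {
        let mut codebook = HashMap::new();
        let mut buffer = vec![5];
        let mut next_id = 27;
        update_buffer(&mut buffer, &[6], &mut codebook, &mut next_id, 5);
        assert_eq!(codebook.get(&[5, 6][..]), Some(&27)); // [5, 6] was recorded
    }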
@@ -1031,7 +1057,7 @@ impl CodebookManager {
             _ => panic!("Invalid algorithm: {}", self.algorithm),
         }

-        // collect buffered updates from this states codebook
+        // collect buffered updates from this state's codebook
         state
             .codebook
             .borrow_mut(py)
.borrow_mut(py)

tests/long_test_with_corpus.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+import torch
+from tqdm import tqdm
+from huggingface_hub import hf_hub_download
+from zip2zip_compression import LZWCompressor
+
+import pytest
+
+
+filename = hf_hub_download(
+    repo_id="kjj0/fineweb10B-gpt2",
+    filename="fineweb_train_000001.bin",
+    repo_type="dataset",
+)
+
+header = torch.from_file(filename, False, 256, dtype=torch.int32)
+assert header[0] == 20240520, "magic number mismatch in the data .bin file"
+assert header[1] == 1, "unsupported version"
+num_tokens = int(header[2])
+
+with open(filename, "rb", buffering=0) as f:
+    tokens = torch.empty(num_tokens, dtype=torch.uint16)
+    f.seek(256 * 4)
+    nbytes = f.readinto(tokens.numpy())
+    assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
+
+compressor = LZWCompressor(
+    initial_vocab_size=50257,
+    max_codebook_size=2048,
+    max_subtokens=8,
+    pad_token_id=50256,
+    disabled_ids=[0, 1, 2, 3],
+)
+
+seq_len = 10_000
+for ts in tqdm(tokens.view(-1, seq_len)):
+    ts_list = ts.tolist()
+    encoded_ts, _, _ = compressor.encode(ts_list)
+    decoded_ts = compressor.decode(encoded_ts)
+
+    assert len(decoded_ts) == len(
+        ts_list
+    ), f"decoded length mismatch: {len(decoded_ts)} != {len(ts_list)}"
+    assert decoded_ts == ts_list, "decoded content mismatch"

tests/test.rs

Lines changed: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
-use hashbrown::{HashMap, HashSet};
+use std::collections::{HashMap, HashSet};
 use zip2zip_compression::{LZWCompressor, CodebookConfig, PaddingStrategy, Codebook};

 fn get_alphabet_codebook_config() -> CodebookConfig {
-    let mut disabled_ids = HashSet::new();
+    let mut disabled_ids: HashSet<usize> = HashSet::new();
     disabled_ids.insert(26); // 'z'

     // 26 letters + 1 for the pad token
-    CodebookConfig::new(27, 100, 5, 0, disabled_ids.clone())
+    CodebookConfig::new(27, 100, 5, 0, Some(disabled_ids))
 }

 fn get_base_letter_to_id_map() -> HashMap<char, usize> {
