1
1
use bumpalo:: { collections:: Vec as BumpVec , Bump } ;
2
- use hashbrown :: { HashMap , HashSet } ;
2
+ use std :: collections :: { HashMap , HashSet , BTreeMap } ;
3
3
use itertools:: Itertools ;
4
4
use pyo3:: prelude:: * ;
5
5
use rayon:: iter:: { IntoParallelRefIterator , ParallelIterator } ;
@@ -33,14 +33,14 @@ impl CodebookConfig {
33
33
max_codebook_size : usize ,
34
34
max_subtokens : usize ,
35
35
pad_token_id : usize ,
36
- disabled_ids : HashSet < usize > ,
36
+ disabled_ids : Option < HashSet < usize > > ,
37
37
) -> Self {
38
38
Self {
39
39
initial_vocab_size,
40
40
max_codebook_size,
41
41
max_subtokens,
42
42
pad_token_id,
43
- disabled_ids,
43
+ disabled_ids : disabled_ids . unwrap_or_default ( ) ,
44
44
}
45
45
}
46
46
}
@@ -196,8 +196,8 @@ impl Codebook {
196
196
result
197
197
}
198
198
199
- pub fn to_decoding_dict ( & self ) -> HashMap < usize , Vec < usize > > {
200
- let mut result = HashMap :: with_capacity ( self . base_ids2hyper_id_map . len ( ) ) ;
199
+ pub fn to_dict ( & self ) -> BTreeMap < usize , Vec < usize > > {
200
+ let mut result = BTreeMap :: new ( ) ;
201
201
for ( ids, id) in self . base_ids2hyper_id_map . iter ( ) {
202
202
result. insert ( * id, ids. clone ( ) ) ;
203
203
}
@@ -291,6 +291,8 @@ impl LZWCompressor {
291
291
292
292
let id = ids[ i] ;
293
293
i += 1 ;
294
+ log:: debug!( "coming id: {}" , id) ;
295
+ log:: debug!( "buffer_ids_to_merge: {:?}" , buffer_ids_to_merge) ;
294
296
295
297
if self . config . disabled_ids . contains ( & id) {
296
298
if buffer_ids_to_merge. len ( ) > 0 {
@@ -478,7 +480,7 @@ impl LZWCompressor {
478
480
if self . config . disabled_ids . contains ( & id) {
479
481
previous_ids. clear ( ) ;
480
482
output_ids. push ( id) ;
481
- log:: debug!( "emitting id: {}" , id) ;
483
+ log:: debug!( "emitting disabled id: {}" , id) ;
482
484
continue ;
483
485
}
484
486
@@ -495,7 +497,7 @@ impl LZWCompressor {
495
497
// 1. the id is a new hyper id because of force merge
496
498
} else if previous_ids. len ( ) == self . config . max_subtokens {
497
499
decoded_ids = previous_ids. clone ( ) ;
498
- // 2. the id is a new hyper id because of cSc pattern merge
500
+ // 2. the id is a new hyper id because of cScSc pattern merge
499
501
} else {
500
502
decoded_ids = previous_ids. clone ( ) ;
501
503
decoded_ids. push ( previous_ids[ 0 ] ) ;
@@ -507,34 +509,53 @@ impl LZWCompressor {
507
509
previous_ids = decoded_ids. clone ( ) ;
508
510
output_ids. extend_from_slice ( & decoded_ids) ;
509
511
log:: debug!( "emitting id: {:?}" , decoded_ids) ;
512
+ continue ;
510
513
}
511
514
// we have decoded the id, we can add it to the output
512
515
output_ids. extend_from_slice ( & decoded_ids) ;
516
+ log:: debug!( "emitting id: {:?}" , decoded_ids) ;
513
517
514
518
// the remaining part is to update the codebook if needed
519
+
520
+ if next_id == self . config . initial_vocab_size + self . config . max_codebook_size {
521
+ log:: debug!( "max codebook size reached, clearing buffer_ids_to_merge" ) ;
522
+ previous_ids. clear ( ) ;
523
+ continue ;
524
+ }
525
+
526
+ // starting case
515
527
if previous_ids. len ( ) == 0 {
516
528
previous_ids = decoded_ids;
517
529
continue ;
518
530
}
519
531
520
- while next_id < self . config . initial_vocab_size + self . config . max_codebook_size
521
- && previous_ids. len ( ) < self . config . max_subtokens && decoded_ids. len ( ) > 0
522
- {
523
- previous_ids. push ( decoded_ids[ 0 ] ) ;
532
+ // the buffer is max size and the buffer is the previous ID
533
+ // so it must not be a new hyper id but an existing one
534
+ // we just clear the buffer and continue
535
+ if previous_ids. len ( ) == self . config . max_subtokens {
536
+ assert ! ( existing_codes. contains( & previous_ids) , "previous_ids: {:?} not in existing_codes: {:?}" , previous_ids, existing_codes) ;
537
+ previous_ids = decoded_ids. clone ( ) ;
538
+ continue ;
539
+ } else {
524
540
525
- if !existing_codes. contains ( & previous_ids) {
526
- codebook. insert ( next_id, previous_ids. clone ( ) ) ;
527
- log:: debug!( "inserting: {:?} -> {:?}" , previous_ids, next_id) ;
528
- next_id += 1 ;
529
- existing_codes. insert ( previous_ids. clone ( ) ) ;
530
- previous_ids = decoded_ids. clone ( ) ;
531
- break ;
532
- } else if previous_ids. len ( ) == self . config . max_subtokens {
533
- previous_ids = decoded_ids. clone ( ) ;
534
- break ;
535
- }
541
+ while decoded_ids. len ( ) > 0
542
+ {
543
+ previous_ids. push ( decoded_ids[ 0 ] ) ;
544
+
545
+ if !existing_codes. contains ( & previous_ids) {
546
+ codebook. insert ( next_id, previous_ids. clone ( ) ) ;
547
+ log:: debug!( "inserting: {:?} -> {:?}" , previous_ids, next_id) ;
548
+ next_id += 1 ;
549
+ existing_codes. insert ( previous_ids. clone ( ) ) ;
550
+ previous_ids = decoded_ids. clone ( ) ;
551
+ break ;
552
+ } else if previous_ids. len ( ) == self . config . max_subtokens {
553
+ previous_ids = decoded_ids. clone ( ) ;
554
+ break ;
555
+ }
536
556
537
- decoded_ids. remove ( 0 ) ;
557
+ decoded_ids. remove ( 0 ) ;
558
+ }
538
559
}
539
560
}
540
561
@@ -611,7 +632,7 @@ impl LZWCompressor {
611
632
max_codebook_size,
612
633
max_subtokens,
613
634
pad_token_id,
614
- disabled_ids,
635
+ Some ( disabled_ids)
615
636
) ,
616
637
}
617
638
}
@@ -875,8 +896,8 @@ impl CodebookManager {
875
896
let mut current_ids: Vec < usize > ;
876
897
if maybe_hid < config. initial_vocab_size {
877
898
current_ids = vec ! [ maybe_hid] ;
878
- } else if let Some ( entry ) = codebook. get_base_ids ( maybe_hid) {
879
- current_ids = entry . clone ( ) ;
899
+ } else if let Some ( base_ids ) = codebook. get_base_ids ( maybe_hid) {
900
+ current_ids = base_ids . clone ( ) ;
880
901
// the following are cases when the maybe_hid is an unknown hyper-id
881
902
// (1) the buffer was full and it was inserted in the codebook
882
903
// TODO, I think the following is not needed, because
@@ -898,33 +919,38 @@ impl CodebookManager {
898
919
899
920
}
900
921
922
+ if state. next_id == config. initial_vocab_size + config. max_codebook_size {
923
+ log:: debug!( "max_codebook_size reached, clearing buffer_ids_to_merge" ) ;
924
+ state. buffer_ids_to_merge . clear ( ) ;
925
+ continue ;
926
+ }
927
+
901
928
// Starting time
902
929
if state. buffer_ids_to_merge . len ( ) == 0 {
903
930
state. buffer_ids_to_merge = current_ids. clone ( ) ;
904
931
continue ;
905
932
}
906
933
907
- while state. next_id < config. initial_vocab_size + config. max_codebook_size
908
- && state. buffer_ids_to_merge . len ( ) < config. max_subtokens && current_ids. len ( ) > 0
909
- {
910
- state. buffer_ids_to_merge . push ( current_ids[ 0 ] ) ;
911
-
912
- // check if it's already in the codebook
913
- if !codebook. contains_key ( & state. buffer_ids_to_merge ) {
914
- codebook. insert ( state. buffer_ids_to_merge . clone ( ) , state. next_id ) ;
915
- log:: debug!( "inserting: {:?} -> {:?}" , state. buffer_ids_to_merge, state. next_id) ;
916
- state. next_id += 1 ;
917
- state. buffer_ids_to_merge = current_ids. clone ( ) ;
918
- break ;
919
- } // reaching max_subtokens, we need to insert the current_ids into the codebook
920
- else if state. buffer_ids_to_merge . len ( ) == config. max_subtokens {
921
- state. buffer_ids_to_merge = current_ids. clone ( ) ;
922
- break ;
923
- }
924
- current_ids. remove ( 0 ) ;
934
+ if state. buffer_ids_to_merge . len ( ) == config. max_subtokens {
935
+ state. buffer_ids_to_merge = current_ids. clone ( ) ;
936
+ continue ;
937
+ } else {
938
+ while current_ids. len ( ) > 0 {
939
+ state. buffer_ids_to_merge . push ( current_ids[ 0 ] ) ;
925
940
941
+ if !codebook. contains_key ( & state. buffer_ids_to_merge ) {
942
+ codebook. insert ( state. buffer_ids_to_merge . clone ( ) , state. next_id ) ;
943
+ log:: debug!( "inserting: {:?} -> {:?}" , state. buffer_ids_to_merge, state. next_id) ;
944
+ state. next_id += 1 ;
945
+ state. buffer_ids_to_merge = current_ids. clone ( ) ;
946
+ break ;
947
+ } else if state. buffer_ids_to_merge . len ( ) == config. max_subtokens {
948
+ state. buffer_ids_to_merge = current_ids. clone ( ) ;
949
+ break ;
950
+ }
951
+ current_ids. remove ( 0 ) ;
952
+ }
926
953
}
927
-
928
954
}
929
955
}
930
956
}
@@ -1031,7 +1057,7 @@ impl CodebookManager {
1031
1057
_ => panic ! ( "Invalid algorithm: {}" , self . algorithm) ,
1032
1058
}
1033
1059
1034
- // collect buffered updates from this state’s codebook
1060
+ // collect buffered updates from this state's codebook
1035
1061
state
1036
1062
. codebook
1037
1063
. borrow_mut ( py)
0 commit comments