@@ -27,9 +27,10 @@ use thiserror::Error;
2727
2828use sentencepiece_sys:: {
2929 spp_bos_id, spp_decode_piece_ids, spp_decode_pieces, spp_encode_as_serialized_proto,
30- spp_eos_id, spp_free, spp_from_serialized_proto, spp_is_unknown, spp_load, spp_new, spp_pad_id,
31- spp_piece_size, spp_piece_to_id, spp_sample_encode_as_serialized_proto,
32- spp_to_serialized_proto, spp_unk_id, SentencePieceProcessor as CSentencePieceProcessor ,
30+ spp_eos_id, spp_free, spp_from_serialized_proto, spp_is_unknown, spp_load, spp_new,
31+ spp_normalize, spp_normalize_with_offsets, spp_pad_id, spp_piece_size, spp_piece_to_id,
32+ spp_sample_encode_as_serialized_proto, spp_to_serialized_proto, spp_unk_id,
33+ SentencePieceProcessor as CSentencePieceProcessor ,
3334} ;
3435
3536mod sentencepiece;
@@ -111,20 +112,20 @@ pub enum CSentencePieceError {
111112}
112113
/// Small wrapper struct to deallocate data automatically.
///
/// Holds a raw pointer/length pair and exposes it as a slice through
/// `Deref`. The pointer may be null (for example when it was initialised
/// with `null_mut()` and never filled in); in that case the wrapper
/// behaves like an empty slice.
struct CData<T> {
    /// Start of the wrapped buffer, or null when there is no buffer.
    data: *const T,
    /// Number of `T` elements readable starting at `data`.
    len: usize,
}

impl<T> Deref for CData<T> {
    type Target = [T];

    /// Views the wrapped buffer as a slice.
    ///
    /// A null `data` pointer yields an empty slice regardless of `len`.
    fn deref(&self) -> &Self::Target {
        // `slice::from_raw_parts` requires a non-null pointer even for
        // zero-length slices, so a null buffer must be handled separately.
        if self.data.is_null() {
            return &[];
        }

        // SAFETY: `data` is non-null here. The code constructing a `CData`
        // is responsible for `data` pointing at `len` initialised, properly
        // aligned elements that stay alive as long as `self`.
        unsafe { slice::from_raw_parts(self.data, self.len) }
    }
}
126127
127- impl Drop for CData {
128+ impl < T > Drop for CData < T > {
128129 fn drop ( & mut self ) {
129130 unsafe { libc:: free ( self . data as * mut c_void ) }
130131 }
@@ -233,7 +234,7 @@ impl SentencePieceProcessor {
233234 )
234235 } ;
235236
236- let c_str = CData {
237+ let c_str: CData < u8 > = CData {
237238 data : decoded,
238239 len : decoded_len,
239240 } ;
@@ -276,7 +277,7 @@ impl SentencePieceProcessor {
276277 )
277278 } ;
278279
279- let c_str = CData {
280+ let c_str: CData < u8 > = CData {
280281 data : decoded,
281282 len : decoded_len,
282283 } ;
@@ -329,6 +330,100 @@ impl SentencePieceProcessor {
329330 len as usize
330331 }
331332
333+ /// Normalize a sentence.
334+ ///
335+ /// This applies the normalization rules specified in the model.
336+ pub fn normalize ( & self , sentence : & str ) -> Result < String , SentencePieceError > {
337+ let mut normalized = std:: ptr:: null_mut :: < u8 > ( ) ;
338+ let mut normalized_len = 0 ;
339+
340+ let status = unsafe {
341+ spp_normalize (
342+ self . inner ,
343+ sentence. as_ptr ( ) as * const c_char ,
344+ sentence. len ( ) ,
345+ & mut normalized,
346+ & mut normalized_len,
347+ )
348+ } ;
349+
350+ let c_str: CData < u8 > = CData {
351+ data : normalized,
352+ len : normalized_len,
353+ } ;
354+
355+ if status == 0 {
356+ let normalized_string = String :: from_utf8 ( c_str. to_owned ( ) )
357+ . expect ( "Normalized sentence is not UTF-8, please report this bug." ) ;
358+
359+ Ok ( normalized_string)
360+ } else {
361+ let c_error = match FromPrimitive :: from_i32 ( status) {
362+ Some ( error) => error,
363+ None => unreachable ! ( ) ,
364+ } ;
365+ Err ( SentencePieceError :: CError ( c_error) )
366+ }
367+ }
368+
369+ /// Normalize a sentence.
370+ ///
371+ /// This applies the normalization rules specified in the model. Returns a
372+ /// tuple containing the normalized string and a vector mapping each byte in
373+ /// the normalized input to the corresponding byte position in the original
374+ /// input.
375+ pub fn normalize_with_offsets (
376+ & self ,
377+ sentence : & str ,
378+ ) -> Result < ( String , Vec < usize > ) , SentencePieceError > {
379+ let mut normalized = std:: ptr:: null_mut :: < u8 > ( ) ;
380+ let mut normalized_len = 0 ;
381+ let mut offsets = std:: ptr:: null_mut :: < usize > ( ) ;
382+ let mut offsets_len = 0 ;
383+
384+ let status = unsafe {
385+ spp_normalize_with_offsets (
386+ self . inner ,
387+ sentence. as_ptr ( ) as * const c_char ,
388+ sentence. len ( ) ,
389+ & mut normalized,
390+ & mut normalized_len,
391+ & mut offsets,
392+ & mut offsets_len,
393+ )
394+ } ;
395+
396+ // Create wrapper structs for the allocated memory to manage their lifetimes
397+ let c_normalized: CData < u8 > = CData {
398+ data : normalized,
399+ len : normalized_len,
400+ } ;
401+
402+ // Create a scope to ensure c_offsets is dropped after we're done with it
403+
404+ {
405+ let c_offsets: CData < usize > = CData {
406+ data : offsets,
407+ len : offsets_len,
408+ } ;
409+
410+ if status == 0 {
411+ let normalized_string = String :: from_utf8 ( c_normalized. to_owned ( ) )
412+ . expect ( "Normalized sentence is not UTF-8, please report this bug." ) ;
413+
414+ let offsets_vec = c_offsets. to_vec ( ) ;
415+
416+ Ok ( ( normalized_string, offsets_vec) )
417+ } else {
418+ let c_error = match FromPrimitive :: from_i32 ( status) {
419+ Some ( error) => error,
420+ None => unreachable ! ( ) ,
421+ } ;
422+ Err ( SentencePieceError :: CError ( c_error) )
423+ }
424+ }
425+ }
426+
332427 pub fn pad_id ( & self ) -> Option < u32 > {
333428 let pad_id = unsafe { spp_pad_id ( self . inner ) } ;
334429 if pad_id < 0 {
@@ -350,7 +445,7 @@ impl SentencePieceProcessor {
350445 }
351446 }
352447
353- fn process_encode_protobuf ( c_proto : CData ) -> Result < Vec < PieceWithId > , SentencePieceError > {
448+ fn process_encode_protobuf ( c_proto : CData < u8 > ) -> Result < Vec < PieceWithId > , SentencePieceError > {
354449 // Errors are communicated as empty data.
355450 if c_proto. is_empty ( ) {
356451 return Err ( SentencePieceError :: EncodeError ) ;
@@ -697,6 +792,29 @@ mod tests {
697792 let protobuf_roundtrip = spp. to_serialized_proto ( ) ;
698793 assert_eq ! ( protobuf, protobuf_roundtrip) ;
699794 }
795+
/// Checks the normalized form the toy model produces for a short sentence.
#[test]
fn normalizes_sentence_with_toy_model() {
    let processor = toy_model().unwrap();
    let normalized = processor.normalize("KADOKAWA ABC ").unwrap();
    assert_eq!(normalized, "▁KADOKAWA▁ABC");
}
804+
/// Checks both the normalized text and the offsets vector returned for a
/// short sentence by the model loaded from `testdata/toy.model`.
#[test]
fn test_normalize_with_offsets() {
    let processor = SentencePieceProcessor::open("testdata/toy.model").unwrap();
    let result = processor.normalize_with_offsets("KADOKAWA ABC ");
    let (text, byte_offsets) = result.unwrap();

    assert_eq!(text, "▁KADOKAWA▁ABC");

    let expected_offsets = vec![0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 24, 24, 27, 28, 29, 30];
    assert_eq!(byte_offsets, expected_offsets);
}
700818}
701819
702820#[ cfg( feature = "albert-tests" ) ]
0 commit comments