Skip to content

Commit ceb93a4

Browse files
committed
Add normalization methods to SentencePieceProcessor
Add `SentencePieceProcessor::normalize` and `SentencePieceProcessor::normalize_with_offsets`.
1 parent fc9f451 commit ceb93a4

File tree

4 files changed

+211
-12
lines changed

4 files changed

+211
-12
lines changed

sentencepiece-sys/src/bindings.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,23 @@ extern "C" {
151151
extern "C" {
152152
/// Foreign declaration of the C wrapper `spp_unk_id`: takes a raw processor
/// pointer and returns a C `int`.
pub fn spp_unk_id(spp: *mut SentencePieceProcessor) -> ::std::os::raw::c_int;
153153
}
154+
extern "C" {
155+
/// Foreign declaration of the C wrapper `spp_normalize`.
///
/// `sentence` / `sentence_len` describe the input bytes (no NUL terminator is
/// required, the length is passed explicitly). Per the C++ implementation
/// added in the same change: when normalization succeeds, a `malloc`-allocated
/// buffer is stored in `*normalized` and its byte length in `*normalized_len`;
/// when the status is not OK, a null pointer and 0 are stored instead.
/// The return value is the status code converted to an integer.
pub fn spp_normalize(
156+
spp: *mut SentencePieceProcessor,
157+
sentence: *const ::std::os::raw::c_char,
158+
sentence_len: usize,
159+
normalized: *mut *mut ::std::os::raw::c_uchar,
160+
normalized_len: *mut usize,
161+
) -> ::std::os::raw::c_int;
162+
}
163+
extern "C" {
163+
/// Foreign declaration of the C wrapper `spp_normalize_with_offsets`.
///
/// Like `spp_normalize`, but additionally returns the normalized-to-original
/// offset table. Per the C++ implementation added in the same change: on
/// success `*normalized` / `*normalized_len` receive a `malloc`-allocated byte
/// buffer and its length, and `*offsets` / `*offsets_len` receive a
/// `malloc`-allocated array of `size_t` and its element count; on failure all
/// pointers are null and all lengths are 0.
/// The return value is the status code converted to an integer.
pub fn spp_normalize_with_offsets(
164+
spp: *mut SentencePieceProcessor,
165+
sentence: *const ::std::os::raw::c_char,
166+
sentence_len: usize,
167+
normalized: *mut *mut ::std::os::raw::c_uchar,
168+
normalized_len: *mut usize,
169+
offsets: *mut *mut usize,
170+
offsets_len: *mut usize,
171+
) -> ::std::os::raw::c_int;
172+
}

sentencepiece-sys/src/ffi/sentencepiece.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ unsigned char *spp_encode_as_serialized_proto(SentencePieceProcessor *spp, char
7575
return data;
7676
}
7777

78-
7978
unsigned char *spp_sample_encode_as_serialized_proto(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, size_t *len, size_t nbest, float alpha) {
8079
auto sentence_view = absl::string_view(sentence, sentence_len);
8180
auto serialized = spp->SampleEncodeAsSerializedProto(sentence_view, static_cast<int>(nbest), alpha);
@@ -87,6 +86,66 @@ unsigned char *spp_sample_encode_as_serialized_proto(SentencePieceProcessor *spp
8786
return data;
8887
}
8988

89+
// Normalizes `sentence` (`sentence_len` bytes, no terminator required) with
// the processor's normalizer and returns the result through out-parameters.
//
// On success `*normalized` is a malloc()ed buffer of `*normalized_len` bytes
// (not NUL-terminated) that the caller must release with free().
// On any failure `*normalized` is nullptr and `*normalized_len` is 0.
//
// Returns the status code as an int. An allocation failure is reported as
// kResourceExhausted, matching spp_normalize_with_offsets.
int spp_normalize(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len) {
  auto sentence_view = absl::string_view(sentence, sentence_len);
  std::string normalized_str;
  auto status = spp->Normalize(sentence_view, &normalized_str);

  if (!status.ok()) {
    *normalized = nullptr;
    *normalized_len = 0;
    return to_underlying_type(status.code());
  }

  *normalized_len = normalized_str.size();
  *normalized = (unsigned char *) malloc(normalized_str.size());
  if (*normalized == nullptr && normalized_str.size() > 0) {
    // Allocation failed: never hand a null buffer with a non-zero length
    // back to the caller, and never memcpy into it.
    *normalized_len = 0;
    return to_underlying_type(sentencepiece::util::StatusCode::kResourceExhausted);
  }

  // malloc(0) may legitimately return nullptr; memcpy must not be called
  // with a null destination even for a zero-byte copy.
  if (normalized_str.size() > 0) {
    memcpy(*normalized, normalized_str.data(), normalized_str.size());
  }

  return to_underlying_type(status.code());
}
106+
107+
// Normalizes `sentence` (`sentence_len` bytes) and additionally returns the
// table that Normalize() fills in for mapping normalized positions back to
// the original input.
//
// On success:
//   *normalized / *normalized_len - malloc()ed byte buffer and its length
//   *offsets    / *offsets_len    - malloc()ed size_t array and its element count
// Both buffers must be released by the caller with free().
// On any failure all pointers are nullptr and all lengths are 0.
//
// Returns the status code as an int; an allocation failure is reported as
// kResourceExhausted.
int spp_normalize_with_offsets(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len, size_t **offsets, size_t *offsets_len) {
  auto sentence_view = absl::string_view(sentence, sentence_len);
  std::string normalized_str;
  std::vector<size_t> norm_to_orig_vec;
  auto status = spp->Normalize(sentence_view, &normalized_str, &norm_to_orig_vec);

  // Start from the "nothing returned" state so every early return below
  // leaves the out-parameters consistent.
  *normalized = nullptr;
  *normalized_len = 0;
  *offsets = nullptr;
  *offsets_len = 0;

  if (!status.ok()) {
    return to_underlying_type(status.code());
  }

  size_t const normalized_bytes = normalized_str.size();
  size_t const offsets_bytes = norm_to_orig_vec.size() * sizeof(size_t);

  // Allocate the normalized string buffer.
  unsigned char *normalized_buf = (unsigned char *) malloc(normalized_bytes);
  if (normalized_buf == nullptr && normalized_bytes > 0) {
    return to_underlying_type(sentencepiece::util::StatusCode::kResourceExhausted);
  }

  // Allocate the offsets buffer; on failure release the string buffer so
  // nothing leaks.
  size_t *offsets_buf = (size_t *) malloc(offsets_bytes);
  if (offsets_buf == nullptr && offsets_bytes > 0) {
    free(normalized_buf);
    return to_underlying_type(sentencepiece::util::StatusCode::kResourceExhausted);
  }

  // malloc(0) may return nullptr; memcpy must not receive a null
  // destination even for a zero-byte copy, so skip empty copies.
  if (normalized_bytes > 0) {
    memcpy(normalized_buf, normalized_str.data(), normalized_bytes);
  }
  if (offsets_bytes > 0) {
    memcpy(offsets_buf, norm_to_orig_vec.data(), offsets_bytes);
  }

  // Publish the results only once both buffers are fully populated.
  *normalized = normalized_buf;
  *normalized_len = normalized_bytes;
  *offsets = offsets_buf;
  *offsets_len = norm_to_orig_vec.size();

  return to_underlying_type(status.code());
}
148+
90149
// Returns the value of eos_id() for the given processor.
int spp_eos_id(SentencePieceProcessor *spp) {
91150
return spp->eos_id();
92151
}

sentencepiece-sys/src/ffi/sentencepiece.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ int spp_piece_size(SentencePieceProcessor *spp);
4242

4343
int spp_unk_id(SentencePieceProcessor *spp);
4444

45+
/* Normalizes `sentence`; on success `*normalized` is a malloc()ed buffer of
 * `*normalized_len` bytes owned by the caller. Returns the status code. */
int spp_normalize(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len);

/* Like spp_normalize, and additionally returns a malloc()ed array of
 * `*offsets_len` offsets in `*offsets`, owned by the caller. Declared here so
 * the header matches the definition in sentencepiece.cpp and the Rust
 * bindings. Returns the status code. */
int spp_normalize_with_offsets(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len, size_t **offsets, size_t *offsets_len);
46+
4547
#ifdef __cplusplus
4648
}
4749
#endif

sentencepiece/src/lib.rs

Lines changed: 129 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ use thiserror::Error;
2727

2828
use sentencepiece_sys::{
2929
spp_bos_id, spp_decode_piece_ids, spp_decode_pieces, spp_encode_as_serialized_proto,
30-
spp_eos_id, spp_free, spp_from_serialized_proto, spp_is_unknown, spp_load, spp_new, spp_pad_id,
31-
spp_piece_size, spp_piece_to_id, spp_sample_encode_as_serialized_proto,
32-
spp_to_serialized_proto, spp_unk_id, SentencePieceProcessor as CSentencePieceProcessor,
30+
spp_eos_id, spp_free, spp_from_serialized_proto, spp_is_unknown, spp_load, spp_new,
31+
spp_normalize, spp_normalize_with_offsets, spp_pad_id, spp_piece_size, spp_piece_to_id,
32+
spp_sample_encode_as_serialized_proto, spp_to_serialized_proto, spp_unk_id,
33+
SentencePieceProcessor as CSentencePieceProcessor,
3334
};
3435

3536
mod sentencepiece;
@@ -111,20 +112,20 @@ pub enum CSentencePieceError {
111112
}
112113

113114
/// Small wrapper struct to deallocate data automatically.
114-
struct CData {
115-
data: *const u8,
115+
struct CData<T> {
116+
// Start of the C-allocated buffer; it is released with `libc::free` when
// the wrapper is dropped.
data: *const T,
116117
// Number of `T` elements (not bytes) exposed through `Deref`.
len: usize,
117118
}
118119

119-
impl Deref for CData {
120-
type Target = [u8];
120+
impl<T> Deref for CData<T> {
121+
type Target = [T];
121122

122123
fn deref(&self) -> &Self::Target {
123124
unsafe { slice::from_raw_parts(self.data, self.len) }
124125
}
125126
}
126127

127-
impl Drop for CData {
128+
impl<T> Drop for CData<T> {
128129
fn drop(&mut self) {
129130
unsafe { libc::free(self.data as *mut c_void) }
130131
}
@@ -233,7 +234,7 @@ impl SentencePieceProcessor {
233234
)
234235
};
235236

236-
let c_str = CData {
237+
let c_str: CData<u8> = CData {
237238
data: decoded,
238239
len: decoded_len,
239240
};
@@ -276,7 +277,7 @@ impl SentencePieceProcessor {
276277
)
277278
};
278279

279-
let c_str = CData {
280+
let c_str: CData<u8> = CData {
280281
data: decoded,
281282
len: decoded_len,
282283
};
@@ -329,6 +330,100 @@ impl SentencePieceProcessor {
329330
len as usize
330331
}
331332

333+
/// Normalize a sentence.
334+
///
335+
/// This applies the normalization rules specified in the model.
336+
pub fn normalize(&self, sentence: &str) -> Result<String, SentencePieceError> {
337+
let mut normalized = std::ptr::null_mut::<u8>();
338+
let mut normalized_len = 0;
339+
340+
let status = unsafe {
341+
spp_normalize(
342+
self.inner,
343+
sentence.as_ptr() as *const c_char,
344+
sentence.len(),
345+
&mut normalized,
346+
&mut normalized_len,
347+
)
348+
};
349+
350+
let c_str: CData<u8> = CData {
351+
data: normalized,
352+
len: normalized_len,
353+
};
354+
355+
if status == 0 {
356+
let normalized_string = String::from_utf8(c_str.to_owned())
357+
.expect("Normalized sentence is not UTF-8, please report this bug.");
358+
359+
Ok(normalized_string)
360+
} else {
361+
let c_error = match FromPrimitive::from_i32(status) {
362+
Some(error) => error,
363+
None => unreachable!(),
364+
};
365+
Err(SentencePieceError::CError(c_error))
366+
}
367+
}
368+
369+
/// Normalize a sentence.
370+
///
371+
/// This applies the normalization rules specified in the model. Returns a
372+
/// tuple containing the normalized string and a vector mapping each byte in
373+
/// the normalized input to the corresponding byte position in the original
374+
/// input.
375+
pub fn normalize_with_offsets(
376+
&self,
377+
sentence: &str,
378+
) -> Result<(String, Vec<usize>), SentencePieceError> {
379+
let mut normalized = std::ptr::null_mut::<u8>();
380+
let mut normalized_len = 0;
381+
let mut offsets = std::ptr::null_mut::<usize>();
382+
let mut offsets_len = 0;
383+
384+
let status = unsafe {
385+
spp_normalize_with_offsets(
386+
self.inner,
387+
sentence.as_ptr() as *const c_char,
388+
sentence.len(),
389+
&mut normalized,
390+
&mut normalized_len,
391+
&mut offsets,
392+
&mut offsets_len,
393+
)
394+
};
395+
396+
// Create wrapper structs for the allocated memory to manage their lifetimes
397+
let c_normalized: CData<u8> = CData {
398+
data: normalized,
399+
len: normalized_len,
400+
};
401+
402+
// Create a scope to ensure c_offsets is dropped after we're done with it
403+
404+
{
405+
let c_offsets: CData<usize> = CData {
406+
data: offsets,
407+
len: offsets_len,
408+
};
409+
410+
if status == 0 {
411+
let normalized_string = String::from_utf8(c_normalized.to_owned())
412+
.expect("Normalized sentence is not UTF-8, please report this bug.");
413+
414+
let offsets_vec = c_offsets.to_vec();
415+
416+
Ok((normalized_string, offsets_vec))
417+
} else {
418+
let c_error = match FromPrimitive::from_i32(status) {
419+
Some(error) => error,
420+
None => unreachable!(),
421+
};
422+
Err(SentencePieceError::CError(c_error))
423+
}
424+
}
425+
}
426+
332427
pub fn pad_id(&self) -> Option<u32> {
333428
let pad_id = unsafe { spp_pad_id(self.inner) };
334429
if pad_id < 0 {
@@ -350,7 +445,7 @@ impl SentencePieceProcessor {
350445
}
351446
}
352447

353-
fn process_encode_protobuf(c_proto: CData) -> Result<Vec<PieceWithId>, SentencePieceError> {
448+
fn process_encode_protobuf(c_proto: CData<u8>) -> Result<Vec<PieceWithId>, SentencePieceError> {
354449
// Errors are communicated as empty data.
355450
if c_proto.is_empty() {
356451
return Err(SentencePieceError::EncodeError);
@@ -697,6 +792,29 @@ mod tests {
697792
let protobuf_roundtrip = spp.to_serialized_proto();
698793
assert_eq!(protobuf, protobuf_roundtrip);
699794
}
795+
796+
#[test]
fn normalizes_sentence_with_toy_model() {
    // Normalizing with the toy model should yield the expected marked-up text.
    let spp = toy_model().unwrap();
    let normalized = spp.normalize("KADOKAWA ABC ").unwrap();

    assert_eq!(normalized, "▁KADOKAWA▁ABC");
}
804+
805+
#[test]
fn test_normalize_with_offsets() {
    let processor = SentencePieceProcessor::open("testdata/toy.model").unwrap();

    // NOTE(review): the expected offsets advance in steps of 3 for the first
    // word, which implies 3-byte input characters; the input literal below is
    // presumably not plain ASCII in the real file — confirm against the
    // repository before relying on this test as written.
    let (text, byte_offsets) = processor
        .normalize_with_offsets("KADOKAWA ABC ")
        .unwrap();

    assert_eq!(text, "▁KADOKAWA▁ABC");
    assert_eq!(
        byte_offsets,
        vec![0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 24, 24, 27, 28, 29, 30]
    );
}
700818
}
701819

702820
#[cfg(feature = "albert-tests")]

0 commit comments

Comments
 (0)