Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sentencepiece-sys/source
Submodule source updated 51 files
+23 −0 .github/dependabot.yml
+5 −1 .github/workflows/cifuzz.yml
+13 −6 .github/workflows/cmake.yml
+45 −0 .github/workflows/cross_build.yml
+4 −0 .github/workflows/requirements/base.in
+316 −0 .github/workflows/requirements/base.txt
+2 −0 .github/workflows/requirements/cibuildwheel.in
+36 −0 .github/workflows/requirements/cibuildwheel.txt
+15 −10 .github/workflows/wheel.yml
+35 −5 CMakeLists.txt
+1 −1 VERSION.txt
+1 −1 doc/api.md
+6 −4 python/setup.py
+162 −25 python/src/sentencepiece/__init__.py
+1 −1 python/src/sentencepiece/_version.py
+229 −20 python/src/sentencepiece/sentencepiece.i
+30 −743 python/src/sentencepiece/sentencepiece_model_pb2.py
+14 −179 python/src/sentencepiece/sentencepiece_pb2.py
+1,695 −709 python/src/sentencepiece/sentencepiece_wrap.cxx
+289 −66 python/test/sentencepiece_test.py
+3 −2 sentencepiece.pc.in
+13 −5 src/CMakeLists.txt
+3 −3 src/bpe_model.cc
+5 −4 src/bpe_model_trainer.h
+199 −154 src/builtin_pb/sentencepiece_model.pb.cc
+191 −96 src/builtin_pb/sentencepiece_model.pb.h
+5 −15 src/common.h
+10 −8 src/filesystem.cc
+7 −7 src/model_factory.cc
+20 −3 src/model_interface.cc
+15 −12 src/normalizer.cc
+5 −1 src/sentencepiece_model.proto
+61 −17 src/sentencepiece_processor.cc
+28 −3 src/sentencepiece_processor.h
+44 −45 src/sentencepiece_processor_test.cc
+138 −1 src/sentencepiece_trainer.cc
+54 −0 src/sentencepiece_trainer.h
+105 −1 src/sentencepiece_trainer_test.cc
+2 −0 src/spec_parser.h
+3 −0 src/spm_train_main.cc
+2 −2 src/testharness.h
+12 −12 src/trainer_factory.cc
+14 −13 src/trainer_interface.cc
+1 −2 src/unigram_model.cc
+97 −59 src/unigram_model_trainer.cc
+2 −3 src/unigram_model_trainer_test.cc
+40 −15 src/util.cc
+10 −0 src/util.h
+8 −13 third_party/absl/container/btree_set.h
+0 −71 third_party/absl/memory/memory.h
+0 −31 third_party/absl/random/distributions.h
20 changes: 20 additions & 0 deletions sentencepiece-sys/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,23 @@ extern "C" {
extern "C" {
/// Returns the vocabulary id of the unknown token (`<unk>`) for this
/// processor.
pub fn spp_unk_id(spp: *mut SentencePieceProcessor) -> ::std::os::raw::c_int;
}
extern "C" {
/// Normalizes `sentence` (a byte slice of `sentence_len` bytes) using the
/// model's normalization rules.
///
/// On success, `*normalized` points to a C-allocated buffer of
/// `*normalized_len` bytes; the caller owns it and must release it with
/// `libc::free` (the Rust side wraps it in `CData` for this). Returns a
/// status code; `0` means success.
pub fn spp_normalize(
spp: *mut SentencePieceProcessor,
sentence: *const ::std::os::raw::c_char,
sentence_len: usize,
normalized: *mut *mut ::std::os::raw::c_uchar,
normalized_len: *mut usize,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// Normalizes `sentence` like [`spp_normalize`], and additionally returns
/// the byte-offset mapping from the normalized text back to the original.
///
/// On success, `*normalized` (length `*normalized_len`) and `*offsets`
/// (length `*offsets_len`) are C-allocated buffers owned by the caller,
/// to be released with `libc::free`. Returns a status code; `0` means
/// success.
pub fn spp_normalize_with_offsets(
spp: *mut SentencePieceProcessor,
sentence: *const ::std::os::raw::c_char,
sentence_len: usize,
normalized: *mut *mut ::std::os::raw::c_uchar,
normalized_len: *mut usize,
offsets: *mut *mut usize,
offsets_len: *mut usize,
) -> ::std::os::raw::c_int;
}
61 changes: 60 additions & 1 deletion sentencepiece-sys/src/ffi/sentencepiece.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ unsigned char *spp_encode_as_serialized_proto(SentencePieceProcessor *spp, char
return data;
}


unsigned char *spp_sample_encode_as_serialized_proto(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, size_t *len, size_t nbest, float alpha) {
auto sentence_view = absl::string_view(sentence, sentence_len);
auto serialized = spp->SampleEncodeAsSerializedProto(sentence_view, static_cast<int>(nbest), alpha);
Expand All @@ -87,6 +86,66 @@ unsigned char *spp_sample_encode_as_serialized_proto(SentencePieceProcessor *spp
return data;
}

// Normalizes `sentence` (of `sentence_len` bytes) with the processor's
// normalization rules.
//
// On success, `*normalized` receives a malloc'd buffer of `*normalized_len`
// bytes; ownership passes to the caller, who must release it with free().
// On failure both out-parameters are zeroed. Returns the underlying
// sentencepiece status code (0 on success).
int spp_normalize(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len) {
  auto sentence_view = absl::string_view(sentence, sentence_len);
  std::string normalized_str;
  auto status = spp->Normalize(sentence_view, &normalized_str);

  if (!status.ok()) {
    *normalized = nullptr;
    *normalized_len = 0;
    return to_underlying_type(status.code());
  }

  *normalized_len = normalized_str.size();
  *normalized = (unsigned char *) malloc(normalized_str.size());
  if (*normalized == nullptr && normalized_str.size() > 0) {
    // Allocation failed; report resource exhaustion like
    // spp_normalize_with_offsets does instead of writing through null.
    *normalized_len = 0;
    return to_underlying_type(sentencepiece::util::StatusCode::kResourceExhausted);
  }
  memcpy(*normalized, normalized_str.data(), normalized_str.size());

  return to_underlying_type(status.code());
}

// Normalizes `sentence` (of `sentence_len` bytes) and also returns the
// mapping from each byte of the normalized text to its byte position in the
// original input.
//
// On success, `*normalized` (length `*normalized_len`) and `*offsets`
// (length `*offsets_len` entries) are malloc'd buffers owned by the caller,
// to be released with free(). On any failure all four out-parameters are
// zeroed and nothing needs freeing. Returns the underlying sentencepiece
// status code (0 on success).
int spp_normalize_with_offsets(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len, size_t **offsets, size_t *offsets_len) {
auto sentence_view = absl::string_view(sentence, sentence_len);
std::string normalized_str;
std::vector<size_t> norm_to_orig_vec;
auto status = spp->Normalize(sentence_view, &normalized_str, &norm_to_orig_vec);

if (!status.ok()) {
*normalized = nullptr;
*normalized_len = 0;
*offsets = nullptr;
*offsets_len = 0;
return to_underlying_type(status.code());
}

// Allocate and copy normalized string
*normalized_len = normalized_str.size();
*normalized = (unsigned char *) malloc(normalized_str.size());
if (*normalized == nullptr && normalized_str.size() > 0) {
// Allocation failed
*normalized_len = 0;
*offsets = nullptr;
*offsets_len = 0;
return to_underlying_type(sentencepiece::util::StatusCode::kResourceExhausted);
}
memcpy(*normalized, normalized_str.data(), normalized_str.size());

// Allocate and copy offsets
*offsets_len = norm_to_orig_vec.size();
*offsets = (size_t *) malloc(norm_to_orig_vec.size() * sizeof(size_t));
if (*offsets == nullptr && norm_to_orig_vec.size() > 0) {
// Allocation failed - free the already allocated normalized string
// so the error path leaks nothing.
free(*normalized);
*normalized = nullptr;
*normalized_len = 0;
*offsets_len = 0;
return to_underlying_type(sentencepiece::util::StatusCode::kResourceExhausted);
}
memcpy(*offsets, norm_to_orig_vec.data(), norm_to_orig_vec.size() * sizeof(size_t));

return to_underlying_type(status.code());
}

// Returns the vocabulary id of the end-of-sentence token (</s>).
int spp_eos_id(SentencePieceProcessor *spp) {
return spp->eos_id();
}
Expand Down
2 changes: 2 additions & 0 deletions sentencepiece-sys/src/ffi/sentencepiece.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ int spp_piece_size(SentencePieceProcessor *spp);

int spp_unk_id(SentencePieceProcessor *spp);

int spp_normalize(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len);

/* Declared here so the header matches the definitions in sentencepiece.cpp
 * and the Rust declaration in bindings.rs; it was previously missing. */
int spp_normalize_with_offsets(SentencePieceProcessor *spp, char const *sentence, size_t sentence_len, unsigned char **normalized, size_t *normalized_len, size_t **offsets, size_t *offsets_len);

#ifdef __cplusplus
}
#endif
140 changes: 129 additions & 11 deletions sentencepiece/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ use thiserror::Error;

use sentencepiece_sys::{
spp_bos_id, spp_decode_piece_ids, spp_decode_pieces, spp_encode_as_serialized_proto,
spp_eos_id, spp_free, spp_from_serialized_proto, spp_is_unknown, spp_load, spp_new, spp_pad_id,
spp_piece_size, spp_piece_to_id, spp_sample_encode_as_serialized_proto,
spp_to_serialized_proto, spp_unk_id, SentencePieceProcessor as CSentencePieceProcessor,
spp_eos_id, spp_free, spp_from_serialized_proto, spp_is_unknown, spp_load, spp_new,
spp_normalize, spp_normalize_with_offsets, spp_pad_id, spp_piece_size, spp_piece_to_id,
spp_sample_encode_as_serialized_proto, spp_to_serialized_proto, spp_unk_id,
SentencePieceProcessor as CSentencePieceProcessor,
};

mod sentencepiece;
Expand Down Expand Up @@ -111,20 +112,20 @@ pub enum CSentencePieceError {
}

/// Small wrapper struct to deallocate data automatically.
struct CData {
data: *const u8,
struct CData<T> {
data: *const T,
len: usize,
}

impl Deref for CData {
type Target = [u8];
impl<T> Deref for CData<T> {
type Target = [T];

fn deref(&self) -> &Self::Target {
unsafe { slice::from_raw_parts(self.data, self.len) }
}
}

impl Drop for CData {
impl<T> Drop for CData<T> {
fn drop(&mut self) {
unsafe { libc::free(self.data as *mut c_void) }
}
Expand Down Expand Up @@ -233,7 +234,7 @@ impl SentencePieceProcessor {
)
};

let c_str = CData {
let c_str: CData<u8> = CData {
data: decoded,
len: decoded_len,
};
Expand Down Expand Up @@ -276,7 +277,7 @@ impl SentencePieceProcessor {
)
};

let c_str = CData {
let c_str: CData<u8> = CData {
data: decoded,
len: decoded_len,
};
Expand Down Expand Up @@ -329,6 +330,100 @@ impl SentencePieceProcessor {
len as usize
}

/// Normalize a sentence.
///
/// This applies the normalization rules specified in the model.
pub fn normalize(&self, sentence: &str) -> Result<String, SentencePieceError> {
    let mut out_ptr = std::ptr::null_mut::<u8>();
    let mut out_len = 0;

    let status = unsafe {
        spp_normalize(
            self.inner,
            sentence.as_ptr() as *const c_char,
            sentence.len(),
            &mut out_ptr,
            &mut out_len,
        )
    };

    // Take ownership of the C-allocated buffer immediately so it is freed
    // on every exit path.
    let buffer: CData<u8> = CData {
        data: out_ptr,
        len: out_len,
    };

    if status != 0 {
        let c_error = match FromPrimitive::from_i32(status) {
            Some(error) => error,
            None => unreachable!(),
        };
        return Err(SentencePieceError::CError(c_error));
    }

    let normalized = String::from_utf8(buffer.to_owned())
        .expect("Normalized sentence is not UTF-8, please report this bug.");
    Ok(normalized)
}

/// Normalize a sentence.
///
/// This applies the normalization rules specified in the model. Returns a
/// tuple containing the normalized string and a vector mapping each byte in
/// the normalized input to the corresponding byte position in the original
/// input.
pub fn normalize_with_offsets(
    &self,
    sentence: &str,
) -> Result<(String, Vec<usize>), SentencePieceError> {
    let mut normalized = std::ptr::null_mut::<u8>();
    let mut normalized_len = 0;
    let mut offsets = std::ptr::null_mut::<usize>();
    let mut offsets_len = 0;

    let status = unsafe {
        spp_normalize_with_offsets(
            self.inner,
            sentence.as_ptr() as *const c_char,
            sentence.len(),
            &mut normalized,
            &mut normalized_len,
            &mut offsets,
            &mut offsets_len,
        )
    };

    // Wrap both C allocations right away so they are freed on every path,
    // including the error return below. (No inner scope is needed: both
    // wrappers drop at the end of the function regardless.)
    let c_normalized: CData<u8> = CData {
        data: normalized,
        len: normalized_len,
    };
    let c_offsets: CData<usize> = CData {
        data: offsets,
        len: offsets_len,
    };

    if status != 0 {
        let c_error = match FromPrimitive::from_i32(status) {
            Some(error) => error,
            None => unreachable!(),
        };
        return Err(SentencePieceError::CError(c_error));
    }

    let normalized_string = String::from_utf8(c_normalized.to_owned())
        .expect("Normalized sentence is not UTF-8, please report this bug.");
    let offsets_vec = c_offsets.to_vec();

    Ok((normalized_string, offsets_vec))
}

pub fn pad_id(&self) -> Option<u32> {
let pad_id = unsafe { spp_pad_id(self.inner) };
if pad_id < 0 {
Expand All @@ -350,7 +445,7 @@ impl SentencePieceProcessor {
}
}

fn process_encode_protobuf(c_proto: CData) -> Result<Vec<PieceWithId>, SentencePieceError> {
fn process_encode_protobuf(c_proto: CData<u8>) -> Result<Vec<PieceWithId>, SentencePieceError> {
// Errors are communicated as empty data.
if c_proto.is_empty() {
return Err(SentencePieceError::EncodeError);
Expand Down Expand Up @@ -697,6 +792,29 @@ mod tests {
let protobuf_roundtrip = spp.to_serialized_proto();
assert_eq!(protobuf, protobuf_roundtrip);
}

#[test]
fn normalizes_sentence_with_toy_model() {
    let model = toy_model().unwrap();
    // Leading/trailing spaces collapse and words are prefixed with U+2581.
    let normalized = model.normalize("KADOKAWA ABC ").unwrap();
    assert_eq!(normalized, "▁KADOKAWA▁ABC");
}

#[test]
fn test_normalize_with_offsets() {
    // Use the shared toy_model() helper for consistency with the other
    // normalization test rather than opening the model file directly.
    let spp = toy_model().unwrap();

    let (normalized, offsets) = spp
        .normalize_with_offsets("KADOKAWA ABC ")
        .unwrap();
    assert_eq!(normalized, "▁KADOKAWA▁ABC");
    // Each byte of the normalized output maps back to a byte position in
    // the original input; multi-byte U+2581 markers repeat their origin.
    assert_eq!(
        offsets,
        vec![0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 24, 24, 27, 28, 29, 30]
    );
}
}

#[cfg(feature = "albert-tests")]
Expand Down
Loading