From 26f0daceb784b930e6e965337491c8f47ab4527d Mon Sep 17 00:00:00 2001 From: Filipp Ozinov Date: Fri, 13 Apr 2018 20:11:10 +0300 Subject: [PATCH 1/3] Add new words at runtime - added support to lang model --- .gitignore | 1 + jamspell/lang_model.cpp | 60 ++++++++++++++++++++++++++++++++++++----- jamspell/lang_model.hpp | 7 +++++ jamspell/utils.cpp | 8 ++++-- jamspell/utils.hpp | 1 + 5 files changed, 69 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index c76d3e8..9193298 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ *.tar.gz *.spell *.cache +.DS_Store diff --git a/jamspell/lang_model.cpp b/jamspell/lang_model.cpp index addd87c..f344428 100644 --- a/jamspell/lang_model.cpp +++ b/jamspell/lang_model.cpp @@ -260,6 +260,7 @@ bool TLangModel::Load(const std::string& modelFileName) { for (auto&& it: WordToId) { IdToWord[it.second] = &it.first; } + BaseModelLastWordID = LastWordID; return true; } @@ -328,6 +329,35 @@ uint64_t TLangModel::GetCheckSum() const { return CheckSum; } +template +inline void IncCount(TRuntimeModelCounts& counts, const T& key, TCount value) { + std::string tmp = DumpKey(key); + uint32_t cityHash32 = CityHash32(&tmp[0], tmp.size()); + counts[cityHash32] += value; +} + +void TLangModel::AddTextFragment(const std::wstring& text, uint32_t count) { + std::wstring trainText = text; + ToLower(trainText); + TSentences sentences = Tokenizer.Process(trainText); + TIdSentences sentenceIds = ConvertToIds(sentences); + + for (size_t i = 0; i < sentenceIds.size(); ++i) { + const TWordIds& words = sentenceIds[i]; + for (auto w: words) { + IncCount(RuntimeModelCounts, w, count); + } + for (ssize_t j = 0; j < (ssize_t)words.size() - 1; ++j) { + TGram2Key key(words[j], words[j+1]); + IncCount(RuntimeModelCounts, key, count); + } + for (ssize_t j = 0; j < (ssize_t)words.size() - 2; ++j) { + TGram3Key key(words[j], words[j+1], words[j+2]); + IncCount(RuntimeModelCounts, key, count); + } + } +} + TWord TLangModel::GetWord(const std::wstring& word) const { auto it = WordToId.find(word); if (it != WordToId.end()) { @@ -372,10 +402,18 @@ double TLangModel::GetGram3Prob(TWordId word1, TWordId word2, TWordId word3) con return countsGram3 / countsGram2; } +enum ECheckPolicy { + CP_Both = 0, + CP_Base = 1, + CP_Runtime = 2, +}; + template TCount GetGramHashCount(T key, const TPerfectHash& ph, - const std::vector>& buckets) + const TRuntimeModelCounts& runtimeModelCounts, + const std::vector>& buckets, + ECheckPolicy checkPolicy) { constexpr int TMP_BUF_SIZE = 128; static char tmpBuff[TMP_BUF_SIZE]; @@ -392,8 +430,17 @@ TCount GetGramHashCount(T key, const std::pair& data = buckets[bucket]; TCount res = TCount(); - if (data.first == CityHash16(tmpBuff, tmpBuffStream.Size())) { - res = UnpackInt32(data.second); + uint32_t cityHash32 = CityHash32(tmpBuff, tmpBuffStream.Size()); + uint16_t cityHash16 = cityHash32 % std::numeric_limits::max(); + if (checkPolicy == CP_Base || checkPolicy == CP_Both) { + if (data.first == cityHash16) { + res += UnpackInt32(data.second); + } + } + if (checkPolicy == CP_Runtime || checkPolicy == CP_Both) { + auto it = runtimeModelCounts.find(cityHash32); + assert(checkPolicy != CP_Runtime || it != runtimeModelCounts.end()); + res += it->second; } return res; } @@ -403,7 +450,8 @@ TCount TLangModel::GetGram1HashCount(TWordId word) const { return TCount(); } TGram1Key key = word; - return GetGramHashCount(key, PerfectHash, Buckets); + ECheckPolicy policy = key >= BaseModelLastWordID ? CP_Runtime : CP_Base; + return GetGramHashCount(key, PerfectHash, RuntimeModelCounts, Buckets, policy); } TCount TLangModel::GetGram2HashCount(TWordId word1, TWordId word2) const { @@ -411,7 +459,7 @@ TCount TLangModel::GetGram2HashCount(TWordId word1, TWordId word2) const { return TCount(); } TGram2Key key({word1, word2}); - return GetGramHashCount(key, PerfectHash, Buckets); + return GetGramHashCount(key, PerfectHash, RuntimeModelCounts, Buckets, CP_Both); } TCount TLangModel::GetGram3HashCount(TWordId word1, TWordId word2, TWordId word3) const { @@ -419,7 +467,7 @@ TCount TLangModel::GetGram3HashCount(TWordId word1, TWordId word2, TWordId word3 return TCount(); } TGram3Key key(word1, word2, word3); - return GetGramHashCount(key, PerfectHash, Buckets); + return GetGramHashCount(key, PerfectHash, RuntimeModelCounts, Buckets, CP_Both); } } // NJamSpell diff --git a/jamspell/lang_model.hpp b/jamspell/lang_model.hpp index 3e1d0f1..e53e4a2 100644 --- a/jamspell/lang_model.hpp +++ b/jamspell/lang_model.hpp @@ -55,6 +55,8 @@ class TRobinHash: public tsl::robin_map { } }; +using TRuntimeModelCounts = tsl::robin_map; + class TLangModel { public: bool Train(const std::string& fileName, const std::string& alphabetFile); @@ -77,6 +79,8 @@ class TLangModel { uint64_t GetCheckSum() const; + void AddTextFragment(const std::wstring& text, uint32_t count = 1); + HANDYPACK(WordToId, LastWordID, TotalWords, VocabSize, PerfectHash, Buckets, Tokenizer, CheckSum) private: @@ -102,6 +106,9 @@ class TLangModel { std::vector> Buckets; TPerfectHash PerfectHash; uint64_t CheckSum; + + TWordId BaseModelLastWordID = 0; + TRuntimeModelCounts RuntimeModelCounts; }; diff --git a/jamspell/utils.cpp b/jamspell/utils.cpp index 925d386..8bb20af 100644 --- a/jamspell/utils.cpp +++ b/jamspell/utils.cpp @@ -152,13 +152,17 @@ wchar_t MakeUpperIfRequired(wchar_t orig, wchar_t sample) { } uint16_t CityHash16(const std::string& str) { - uint32_t hash = CityHash32(&str[0], str.size()); + uint32_t hash = ::CityHash32(&str[0], str.size()); return hash % std::numeric_limits::max(); } uint16_t CityHash16(const char* str, size_t size) { - uint32_t hash = CityHash32(str, size); + uint32_t hash = ::CityHash32(str, size); return hash % std::numeric_limits::max(); } +uint32_t CityHash32(const char *str, size_t size) { + return ::CityHash32(str, size); +} + } // NJamSpell diff --git a/jamspell/utils.hpp b/jamspell/utils.hpp index 882714d..df0de27 100644 --- a/jamspell/utils.hpp +++ b/jamspell/utils.hpp @@ -62,5 +62,6 @@ void ToLower(std::wstring& text); wchar_t MakeUpperIfRequired(wchar_t orig, wchar_t sample); uint16_t CityHash16(const std::string& str); uint16_t CityHash16(const char* str, size_t size); +uint32_t CityHash32(const char* str, size_t size); } // NJamSpell From 5168b5c49bfa2f0463e1ea939e5852e0866c6c5d Mon Sep 17 00:00:00 2001 From: Filipp Ozinov Date: Fri, 13 Apr 2018 23:56:01 +0300 Subject: [PATCH 2/3] Fixed tests --- jamspell/lang_model.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jamspell/lang_model.cpp b/jamspell/lang_model.cpp index f344428..796e119 100644 --- a/jamspell/lang_model.cpp +++ b/jamspell/lang_model.cpp @@ -437,10 +437,13 @@ TCount GetGramHashCount(T key, res += UnpackInt32(data.second); } } + if (checkPolicy == CP_Runtime || checkPolicy == CP_Both) { auto it = runtimeModelCounts.find(cityHash32); assert(checkPolicy != CP_Runtime || it != runtimeModelCounts.end()); - res += it->second; + if (it != runtimeModelCounts.end()) { + res += it->second; + } } return res; } From 33ff71b5ee57480ce89547bf447e0d94ff3dd1d4 Mon Sep 17 00:00:00 2001 From: Filipp Ozinov Date: Sat, 14 Apr 2018 00:10:31 +0300 Subject: [PATCH 3/3] Indexing runtime words --- jamspell/lang_model.cpp | 9 +++++++++ jamspell/lang_model.hpp | 3 +++ jamspell/spell_corrector.cpp | 16 ++++++++++++++++ jamspell/spell_corrector.hpp | 1 + 4 files changed, 29 insertions(+) diff --git a/jamspell/lang_model.cpp b/jamspell/lang_model.cpp index 796e119..6bff761 100644 --- a/jamspell/lang_model.cpp +++ b/jamspell/lang_model.cpp @@ -321,10 +321,19 @@ TWord TLangModel::GetWordById(TWordId wid) const { return TWord(*IdToWord[wid]); } +const std::wstring& TLangModel::GetWstrById(TWordId wid) const { + assert(wid < IdToWord.size()); + return *IdToWord[wid]; +} + TCount TLangModel::GetWordCount(TWordId wid) const { return GetGram1HashCount(wid); } +TWordId TLangModel::GetLastWordID() const { + return LastWordID; +} + uint64_t TLangModel::GetCheckSum() const { return CheckSum; } diff --git a/jamspell/lang_model.hpp b/jamspell/lang_model.hpp index e53e4a2..5a56fe8 100644 --- a/jamspell/lang_model.hpp +++ b/jamspell/lang_model.hpp @@ -75,8 +75,11 @@ class TLangModel { TWordId GetWordId(const TWord& word); TWordId GetWordIdNoCreate(const TWord& word) const; TWord GetWordById(TWordId wid) const; + const std::wstring& GetWstrById(TWordId wid) const; TCount GetWordCount(TWordId wid) const; + TWordId GetLastWordID() const; + uint64_t GetCheckSum() const; void AddTextFragment(const std::wstring& text, uint32_t count = 1); diff --git a/jamspell/spell_corrector.cpp b/jamspell/spell_corrector.cpp index 6a0b3d1..480d9f0 100644 --- a/jamspell/spell_corrector.cpp +++ b/jamspell/spell_corrector.cpp @@ -254,6 +254,22 @@ void TSpellCorrector::SetMaxCandiatesToCheck(size_t maxCandidatesToCheck) { MaxCandiatesToCheck = maxCandidatesToCheck; } +void TSpellCorrector::AddTextFragment(const std::wstring& text, uint32_t count) { + TWordId startWordID = LangModel.GetLastWordID(); + LangModel.AddTextFragment(text, count); + TWordId endWordID = LangModel.GetLastWordID(); + for (TWordId wid = startWordID; wid < endWordID; ++wid) { + const std::wstring& w = LangModel.GetWstrById(wid); + auto deletes = GetDeletes2(w); + for (auto&& w1: deletes) { + Deletes1->Insert(WideToUTF8(w1.back())); + for (size_t i = 0; i < w1.size() - 1; ++i) { + Deletes2->Insert(WideToUTF8(w1[i])); + } + } + } +} + template inline void AddVec(T& target, const T& source) { target.insert(target.end(), source.begin(), source.end()); diff --git a/jamspell/spell_corrector.hpp b/jamspell/spell_corrector.hpp index 6c60770..c3b0d14 100644 --- a/jamspell/spell_corrector.hpp +++ b/jamspell/spell_corrector.hpp @@ -18,6 +18,7 @@ class TSpellCorrector { std::wstring FixFragmentNormalized(const std::wstring& text) const; void SetPenalty(double knownWordsPenaly, double unknownWordsPenalty); void SetMaxCandiatesToCheck(size_t maxCandidatesToCheck); + void AddTextFragment(const std::wstring& text, uint32_t count = 1); private: void FilterCandidatesByFrequency(std::unordered_set& uniqueCandidates, NJamSpell::TWord origWord) const; NJamSpell::TWords Edits(const NJamSpell::TWord& word) const;