From bb184e286b2f986cad9b904dbe50c2d6b7f6561e Mon Sep 17 00:00:00 2001 From: ShikiSuen Date: Mon, 14 Mar 2022 12:48:40 +0800 Subject: [PATCH] LMInstantiator // Add m_userSymbolModel, etc. plus comment fix. --- Source/Modules/LangModelRelated/LMInstantiator.h | 7 ++++++- .../Modules/LangModelRelated/LMInstantiator.mm | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/Source/Modules/LangModelRelated/LMInstantiator.h b/Source/Modules/LangModelRelated/LMInstantiator.h index 0357a464..c8b0c504 100644 --- a/Source/Modules/LangModelRelated/LMInstantiator.h +++ b/Source/Modules/LangModelRelated/LMInstantiator.h @@ -27,6 +27,7 @@ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR TH #include "PhraseReplacementMap.h" #include "SymbolLM.h" #include "UserPhrasesLM.h" +#include "UserSymbolLM.h" #include #include @@ -89,8 +90,11 @@ public: /// @param userPhrasesPath The path of user phrases. /// @param excludedPhrasesPath The path of excluded phrases. void loadUserPhrases(const char* userPhrasesPath, const char* excludedPhrasesPath); + /// Asks to load the user symbol data at the given path. + /// @param userSymbolDataPath The path of user symbol data. + void loadUserSymbolData(const char* userPhrasesPath); /// Asks to load the user associated phrases at the given path. - /// @param userAssociatedPhrasesPath The path of the phrase replacement table. + /// @param userAssociatedPhrasesPath The path of the user associated phrases. void loadUserAssociatedPhrases(const char* userAssociatedPhrasesPath); /// Asks to load the phrase replacement table at the given path. /// @param phraseReplacementPath The path of the phrase replacement table. @@ -150,6 +154,7 @@ protected: CNSLM m_cnsModel; UserPhrasesLM m_userPhrases; UserPhrasesLM m_excludedPhrases; + UserSymbolLM m_userSymbolModel; PhraseReplacementMap m_phraseReplacement; AssociatedPhrases m_associatedPhrases; bool m_phraseReplacementEnabled; diff --git a/Source/Modules/LangModelRelated/LMInstantiator.mm b/Source/Modules/LangModelRelated/LMInstantiator.mm index 948f9aff..870468c8 100644 --- a/Source/Modules/LangModelRelated/LMInstantiator.mm +++ b/Source/Modules/LangModelRelated/LMInstantiator.mm @@ -32,6 +32,7 @@ LMInstantiator::~LMInstantiator() m_languageModel.close(); m_miscModel.close(); m_userPhrases.close(); + m_userSymbolModel.close(); m_cnsModel.close(); m_excludedPhrases.close(); m_phraseReplacement.close(); @@ -103,6 +104,14 @@ void LMInstantiator::loadUserPhrases(const char* userPhrasesDataPath, } } +void LMInstantiator::loadUserSymbolData(const char *userSymbolDataPath) +{ + if (userSymbolDataPath) { + m_userSymbolModel.close(); + m_userSymbolModel.open(userSymbolDataPath); + } +} + void LMInstantiator::loadUserAssociatedPhrases(const char *userAssociatedPhrasesPath) { if (userAssociatedPhrasesPath) { @@ -140,6 +149,7 @@ const std::vector LMInstantiator::unigramsForKey(const std std::vector miscUnigrams; std::vector symbolUnigrams; std::vector userUnigrams; + std::vector userSymbolUnigrams; std::vector cnsUnigrams; std::unordered_set excludedValues; @@ -175,6 +185,11 @@ const std::vector LMInstantiator::unigramsForKey(const std symbolUnigrams = filterAndTransformUnigrams(rawSymbolUnigrams, excludedValues, insertedValues); } + if (m_userSymbolModel.hasUnigramsForKey(key) && m_symbolEnabled) { + std::vector rawUserSymbolUnigrams = m_userSymbolModel.unigramsForKey(key); + userSymbolUnigrams = filterAndTransformUnigrams(rawUserSymbolUnigrams, excludedValues, insertedValues); + } + if (m_cnsModel.hasUnigramsForKey(key) && m_cnsEnabled) { std::vector rawCNSUnigrams = m_cnsModel.unigramsForKey(key); cnsUnigrams = filterAndTransformUnigrams(rawCNSUnigrams, excludedValues, insertedValues); @@ -183,6 +198,7 @@ const std::vector LMInstantiator::unigramsForKey(const std allUnigrams.insert(allUnigrams.begin(), userUnigrams.begin(), userUnigrams.end()); allUnigrams.insert(allUnigrams.end(), cnsUnigrams.begin(), cnsUnigrams.end()); allUnigrams.insert(allUnigrams.begin(), miscUnigrams.begin(), miscUnigrams.end()); + allUnigrams.insert(allUnigrams.end(), userSymbolUnigrams.begin(), userSymbolUnigrams.end()); allUnigrams.insert(allUnigrams.end(), symbolUnigrams.begin(), symbolUnigrams.end()); return allUnigrams; }