LMInstantiator // Add m_userSymbolModel, etc. plus comment fix.

This commit is contained in:
ShikiSuen 2022-03-14 12:48:40 +08:00
parent 15baf6d960
commit 89357c1b91
2 changed files with 22 additions and 1 deletions

View File

@ -27,6 +27,7 @@ TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR TH
#include "PhraseReplacementMap.h" #include "PhraseReplacementMap.h"
#include "SymbolLM.h" #include "SymbolLM.h"
#include "UserPhrasesLM.h" #include "UserPhrasesLM.h"
#include "UserSymbolLM.h"
#include <stdio.h> #include <stdio.h>
#include <unordered_set> #include <unordered_set>
@ -89,8 +90,11 @@ public:
/// @param userPhrasesPath The path of user phrases. /// @param userPhrasesPath The path of user phrases.
/// @param excludedPhrasesPath The path of excluded phrases. /// @param excludedPhrasesPath The path of excluded phrases.
void loadUserPhrases(const char* userPhrasesPath, const char* excludedPhrasesPath); void loadUserPhrases(const char* userPhrasesPath, const char* excludedPhrasesPath);
/// Asks to load the user symbol data at the given path.
/// Any previously loaded user symbol data is closed before the new file is
/// opened; passing a null path leaves the currently loaded data untouched.
/// @param userSymbolDataPath The path of user symbol data.
void loadUserSymbolData(const char* userSymbolDataPath);
/// Asks to load the user associated phrases at the given path. /// Asks to load the user associated phrases at the given path.
/// @param userAssociatedPhrasesPath The path of the phrase replacement table. /// @param userAssociatedPhrasesPath The path of the user associated phrases.
void loadUserAssociatedPhrases(const char* userAssociatedPhrasesPath); void loadUserAssociatedPhrases(const char* userAssociatedPhrasesPath);
/// Asks to load the phrase replacement table at the given path. /// Asks to load the phrase replacement table at the given path.
/// @param phraseReplacementPath The path of the phrase replacement table. /// @param phraseReplacementPath The path of the phrase replacement table.
@ -150,6 +154,7 @@ protected:
CNSLM m_cnsModel; CNSLM m_cnsModel;
UserPhrasesLM m_userPhrases; UserPhrasesLM m_userPhrases;
UserPhrasesLM m_excludedPhrases; UserPhrasesLM m_excludedPhrases;
UserSymbolLM m_userSymbolModel;
PhraseReplacementMap m_phraseReplacement; PhraseReplacementMap m_phraseReplacement;
AssociatedPhrases m_associatedPhrases; AssociatedPhrases m_associatedPhrases;
bool m_phraseReplacementEnabled; bool m_phraseReplacementEnabled;

View File

@ -32,6 +32,7 @@ LMInstantiator::~LMInstantiator()
m_languageModel.close(); m_languageModel.close();
m_miscModel.close(); m_miscModel.close();
m_userPhrases.close(); m_userPhrases.close();
m_userSymbolModel.close();
m_cnsModel.close(); m_cnsModel.close();
m_excludedPhrases.close(); m_excludedPhrases.close();
m_phraseReplacement.close(); m_phraseReplacement.close();
@ -103,6 +104,14 @@ void LMInstantiator::loadUserPhrases(const char* userPhrasesDataPath,
} }
} }
/// Reloads the user symbol model from the file at the given path.
/// A null path is ignored; otherwise the currently opened model is closed
/// first and then re-opened on the new path.
void LMInstantiator::loadUserSymbolData(const char *userSymbolDataPath)
{
    // Nothing to load: keep whatever data is currently open.
    if (!userSymbolDataPath) {
        return;
    }

    m_userSymbolModel.close();
    m_userSymbolModel.open(userSymbolDataPath);
}
void LMInstantiator::loadUserAssociatedPhrases(const char *userAssociatedPhrasesPath) void LMInstantiator::loadUserAssociatedPhrases(const char *userAssociatedPhrasesPath)
{ {
if (userAssociatedPhrasesPath) { if (userAssociatedPhrasesPath) {
@ -140,6 +149,7 @@ const std::vector<Gramambular::Unigram> LMInstantiator::unigramsForKey(const std
std::vector<Gramambular::Unigram> miscUnigrams; std::vector<Gramambular::Unigram> miscUnigrams;
std::vector<Gramambular::Unigram> symbolUnigrams; std::vector<Gramambular::Unigram> symbolUnigrams;
std::vector<Gramambular::Unigram> userUnigrams; std::vector<Gramambular::Unigram> userUnigrams;
std::vector<Gramambular::Unigram> userSymbolUnigrams;
std::vector<Gramambular::Unigram> cnsUnigrams; std::vector<Gramambular::Unigram> cnsUnigrams;
std::unordered_set<std::string> excludedValues; std::unordered_set<std::string> excludedValues;
@ -175,6 +185,11 @@ const std::vector<Gramambular::Unigram> LMInstantiator::unigramsForKey(const std
symbolUnigrams = filterAndTransformUnigrams(rawSymbolUnigrams, excludedValues, insertedValues); symbolUnigrams = filterAndTransformUnigrams(rawSymbolUnigrams, excludedValues, insertedValues);
} }
if (m_userSymbolModel.hasUnigramsForKey(key) && m_symbolEnabled) {
std::vector<Gramambular::Unigram> rawUserSymbolUnigrams = m_userSymbolModel.unigramsForKey(key);
userSymbolUnigrams = filterAndTransformUnigrams(rawUserSymbolUnigrams, excludedValues, insertedValues);
}
if (m_cnsModel.hasUnigramsForKey(key) && m_cnsEnabled) { if (m_cnsModel.hasUnigramsForKey(key) && m_cnsEnabled) {
std::vector<Gramambular::Unigram> rawCNSUnigrams = m_cnsModel.unigramsForKey(key); std::vector<Gramambular::Unigram> rawCNSUnigrams = m_cnsModel.unigramsForKey(key);
cnsUnigrams = filterAndTransformUnigrams(rawCNSUnigrams, excludedValues, insertedValues); cnsUnigrams = filterAndTransformUnigrams(rawCNSUnigrams, excludedValues, insertedValues);
@ -183,6 +198,7 @@ const std::vector<Gramambular::Unigram> LMInstantiator::unigramsForKey(const std
allUnigrams.insert(allUnigrams.begin(), userUnigrams.begin(), userUnigrams.end()); allUnigrams.insert(allUnigrams.begin(), userUnigrams.begin(), userUnigrams.end());
allUnigrams.insert(allUnigrams.end(), cnsUnigrams.begin(), cnsUnigrams.end()); allUnigrams.insert(allUnigrams.end(), cnsUnigrams.begin(), cnsUnigrams.end());
allUnigrams.insert(allUnigrams.begin(), miscUnigrams.begin(), miscUnigrams.end()); allUnigrams.insert(allUnigrams.begin(), miscUnigrams.begin(), miscUnigrams.end());
allUnigrams.insert(allUnigrams.end(), userSymbolUnigrams.begin(), userSymbolUnigrams.end());
allUnigrams.insert(allUnigrams.end(), symbolUnigrams.begin(), symbolUnigrams.end()); allUnigrams.insert(allUnigrams.end(), symbolUnigrams.begin(), symbolUnigrams.end());
return allUnigrams; return allUnigrams;
} }