vChewing-macOS/Source/UserOverrideModel.cpp

203 lines
6.2 KiB
C++

//
// UserOverrideModel.cpp
//
// Copyright (c) 2017 The McBopomofo Project.
//
// Permission is hereby granted, free of charge, to any person
// obtaining a copy of this software and associated documentation
// files (the "Software"), to deal in the Software without
// restriction, including without limitation the rights to use,
// copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
//
#include "UserOverrideModel.h"
#include <cassert>
#include <cmath>
#include <sstream>
using namespace McBopomofo;
// Scores whose time-decay factor falls below this value are treated as 0.
// 1.0 / 1048576.0 == 2^-20. The decay factor is exp(elapsed * log(0.5) /
// decayConstant) == 0.5^(elapsed / decayConstant), so an override stops
// counting after roughly 20 half-lives ("about 20 generations").
// NOTE: "Threshould" is a misspelling of "Threshold"; the name is kept as-is
// because Score() below refers to it.
static const double DecayThreshould = 1.0 / 1048576.0;
// Forward declarations of the file-local helpers defined at the end of this
// file.
static double Score(size_t eventCount,
size_t totalCount,
double eventTimestamp,
double timestamp,
double lambda);
static string WalkedNodesToKey(const std::vector<NodeAnchor>& walkedNodes,
size_t cursorIndex);
// Creates a model that remembers at most `capacity` distinct contexts.
//
// `decayConstant` acts as a half-life, in the same units as the timestamps
// passed to observe()/suggest(). Scores are multiplied by
// exp(elapsed * m_decayExponent) == 0.5^(elapsed / decayConstant), so an
// override's weight halves every `decayConstant` time units.
//
// Preconditions (checked with assert):
//   - capacity > 0
//   - decayConstant > 0. Zero would make the exponent infinite, giving
//     degenerate 0/NaN/inf scores. A negative value would make older
//     overrides weigh more instead of less.
UserOverrideModel::UserOverrideModel(size_t capacity, double decayConstant)
    : m_capacity(capacity) {
    assert(m_capacity > 0);
    assert(decayConstant > 0.0);
    m_decayExponent = std::log(0.5) / decayConstant;
}
// Records that `candidate` was chosen at `timestamp` in the context derived
// from `walkedNodes` around `cursorIndex`.
//
// Contexts live in m_lruList, ordered most-recently-used first, and
// m_lruMap maps each context key to its list entry. Observing a known
// context moves it to the front. Observing a new context inserts it at the
// front and, if that pushes the list past m_capacity, evicts the entry at
// the back (the least recently used one).
void UserOverrideModel::observe(const std::vector<NodeAnchor>& walkedNodes,
                                size_t cursorIndex,
                                const string& candidate,
                                double timestamp) {
    const string key = WalkedNodesToKey(walkedNodes, cursorIndex);

    auto known = m_lruMap.find(key);
    if (known != m_lruMap.end()) {
        // splice() relinks the node, so `entry` stays valid afterwards.
        auto entry = known->second;
        m_lruList.splice(m_lruList.begin(), m_lruList, entry);
        entry->second.update(candidate, timestamp);
        return;
    }

    m_lruList.emplace_front(key, Observation());
    m_lruList.front().second.update(candidate, timestamp);
    m_lruMap.emplace(key, m_lruList.begin());

    if (m_lruList.size() > m_capacity) {
        m_lruMap.erase(m_lruList.back().first);
        m_lruList.pop_back();
    }
}
// Returns the override candidate with the highest Score() for the context
// derived from `walkedNodes` around `cursorIndex`, evaluated at `timestamp`.
// Returns an empty string when the context has never been observed or when
// no candidate scores above zero. On equal scores, the candidate that comes
// first in the iteration order of `overrides` is kept.
string UserOverrideModel::suggest(const std::vector<NodeAnchor>& walkedNodes,
                                  size_t cursorIndex,
                                  double timestamp) {
    const auto found = m_lruMap.find(WalkedNodesToKey(walkedNodes, cursorIndex));
    if (found == m_lruMap.end()) {
        return string();
    }
    const Observation& observation = found->second->second;

    string best;
    double bestScore = 0.0;
    for (const auto& entry : observation.overrides) {
        const double candidateScore = Score(entry.second.count,
                                            observation.count,
                                            entry.second.timestamp,
                                            timestamp,
                                            m_decayExponent);
        // A zero score means the override has decayed away; skip it.
        if (candidateScore != 0.0 && candidateScore > bestScore) {
            best = entry.first;
            bestScore = candidateScore;
        }
    }
    return best;
}
// Records one more selection of `candidate` in this context. It bumps the
// context-wide total, creates the candidate's entry on first use, bumps its
// count, and stores `timestamp` as the candidate's latest selection time.
void UserOverrideModel::Observation::update(const string& candidate,
                                            double timestamp) {
    ++count;
    auto& entry = overrides[candidate];
    ++entry.count;
    entry.timestamp = timestamp;
}
// Scores one candidate override within a context.
//
// The score is the fraction of the context's observations that picked this
// candidate (eventCount / totalCount), multiplied by the time-decay factor
// exp((timestamp - eventTimestamp) * lambda). `lambda` is expected to be
// negative, so the factor shrinks as the override ages.
//
// Returns 0.0 when:
//   - the decay factor has dropped below DecayThreshould, or
//   - totalCount is 0, which would otherwise produce 0/0 == NaN.
static double Score(size_t eventCount,
                    size_t totalCount,
                    double eventTimestamp,
                    double timestamp,
                    double lambda) {
    if (totalCount == 0) {
        return 0.0;
    }
    const double decay = std::exp((timestamp - eventTimestamp) * lambda);
    if (decay < DecayThreshould) {
        return 0.0;
    }
    const double prob =
        static_cast<double>(eventCount) / static_cast<double>(totalCount);
    return prob * decay;
}
// Builds the key identifying the context around `cursorIndex`.
//
// Finding the current node:
//   - Nodes are scanned from the start, summing each node's spanningLength.
//   - The scan stops at the first node whose cumulative span is
//     >= cursorIndex. If no node reaches it, the last node is used.
//   - That node is the "current" node; the up to two nodes before it supply
//     context.
//
// Key format:
//     "(" + anterior + "," + prev + "," + currentKey + ")"
// - prev is the node immediately before the current one, and anterior is
//   the node before that.
// - Each of prev and anterior is rendered as "(key,value)", or "()" if that
//   node does not exist.
// - For the current node only the key is used (not its value).
//
// Returns "" when `walkedNodes` is empty.
static string WalkedNodesToKey(const std::vector<NodeAnchor>& walkedNodes,
                               size_t cursorIndex) {
    if (walkedNodes.empty()) {
        return "";
    }

    // Locate the current node by index rather than copying NodeAnchors.
    size_t currentIndex = 0;
    size_t spanned = walkedNodes[0].spanningLength;
    while (spanned < cursorIndex && currentIndex + 1 < walkedNodes.size()) {
        ++currentIndex;
        spanned += walkedNodes[currentIndex].spanningLength;
    }

    // Renders the node `distance` positions before the current one as
    // "(key,value)", or "()" when there is no such node.
    auto describeContextNode = [&walkedNodes, currentIndex](size_t distance) -> string {
        if (distance > currentIndex) {
            return "()";
        }
        const auto& kv = walkedNodes[currentIndex - distance].node->currentKeyValue();
        std::stringstream s;
        s << "(" << kv.key << "," << kv.value << ")";
        return s.str();
    };

    const string current = walkedNodes[currentIndex].node->currentKeyValue().key;
    const string prev = describeContextNode(1);
    const string anterior = describeContextNode(2);
    return "(" + anterior + "," + prev + "," + current + ")";
}