vChewing-macOS/Packages/vChewing_LangModelAssembly/Sources/LangModelAssembly/SubLMs/lmUserOverride.swift

362 lines
13 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// (c) 2021 and onwards The vChewing Project (MIT-NTL License).
// Refactored from the Cpp version of this class by Lukhnos Liu (MIT License).
// ====================
// This code is released under the MIT license (SPDX-License-Identifier: MIT)
// ... with NTL restriction stating that:
// No trademark license is granted to use the trade names, trademarks, service
// marks, or product names of Contributor, except as required to fulfill notice
// requirements defined in MIT License.
import Foundation
import Megrez
import Shared
extension vChewingLM {
  /// User override model (UOM).
  ///
  /// Remembers which candidate the user picked for a given context key and can
  /// later suggest it again, weighted by a time-based exponential decay.
  public class LMUserOverride {
    // MARK: - Main

    /// Capacity of the LRU bookkeeping; clamped to at least 1 by `init`.
    var mutCapacity: Int
    /// Decay exponent, `log(0.5) / decayConstant` (negative for a positive half-life).
    var mutDecayExponent: Double
    /// Observation pairs; new and refreshed entries are inserted at index 0.
    var mutLRUList: [KeyObservationPair] = []
    /// Observation pairs addressed by their observation key.
    var mutLRUMap: [String: KeyObservationPair] = [:]
    /// Decay factors below this value (2^-20) make an override score zero.
    let kDecayThreshold: Double = 1.0 / 1_048_576.0
    /// Fallback destination used by `saveData(toURL:)` when no URL is passed.
    var fileSaveLocationURL: URL
    /// Default half-life: 3600 * 6, i.e. six hours if timestamps are in seconds.
    public static let kObservedOverrideHalfLife: Double = 3600.0 * 6

    /// - Parameters:
    ///   - capacity: LRU capacity; values below 1 are raised to 1.
    ///   - decayConstant: Half-life in the same unit as the timestamps given to this model.
    ///   - dataURL: Default file location for `saveData(toURL:)`.
    public init(capacity: Int = 500, decayConstant: Double = LMUserOverride.kObservedOverrideHalfLife, dataURL: URL) {
      fileSaveLocationURL = dataURL
      mutDecayExponent = log(0.5) / decayConstant
      mutCapacity = max(capacity, 1)  // Keep the capacity strictly positive.
    }

    /// Records the candidate the user chose at `cursor`.
    ///
    /// The walk before the change (`walkedBefore`) and the walk after it (`walkedAfter`)
    /// are compared. The chosen node is taken from `walkedAfter` and its predecessor from
    /// `walkedBefore`. Nothing is recorded if either walk is empty, the walks cover
    /// different key counts, the chosen node spans more than 3 keys, or no key can be
    /// formed.
    public func performObservation(
      walkedBefore: [Megrez.Compositor.Node], walkedAfter: [Megrez.Compositor.Node],
      cursor: Int, timestamp: Double, saveCallback: @escaping () -> Void
    ) {
      guard !walkedAfter.isEmpty, !walkedBefore.isEmpty,
        walkedBefore.totalKeyCount == walkedAfter.totalKeyCount
      else { return }

      // `findNode` writes a position into `target`. NOTE(review): presumably the end
      // position of the found node; confirm against the Megrez array extension.
      var currentPosition = 0
      guard let chosenNode = walkedAfter.findNode(at: cursor, target: &currentPosition),
        chosenNode.spanLength <= 3,  // Longer phrases are left to user-defined phrases.
        currentPosition > 0
      else { return }

      // The preceding node is looked up in the walk from *before* the change.
      var precedingPosition = 0
      guard let precedingNode = walkedBefore.findNode(at: currentPosition - 1, target: &precedingPosition) else {
        return
      }

      let shouldForceHighScore = chosenNode.spanLength > precedingNode.spanLength
      let isBreakingUpPhrase = chosenNode.spanLength == 1 && precedingNode.spanLength > 1
      let observationKey = LMUserOverride.formObservationKey(
        walkedNodes: walkedAfter,
        headIndex: isBreakingUpPhrase ? currentPosition : precedingPosition
      )
      if observationKey.isEmpty { return }

      doObservation(
        key: observationKey, candidate: chosenNode.currentUnigram.value, timestamp: timestamp,
        forceHighScoreOverride: shouldForceHighScore, saveCallback: saveCallback
      )
    }

    /// Returns the remembered suggestion for the node found at `cursor` in `currentWalk`,
    /// or an empty `Suggestion` when no node is found there.
    public func fetchSuggestion(
      currentWalk: [Megrez.Compositor.Node], cursor: Int, timestamp: Double
    ) -> Suggestion {
      var headPosition = 0
      guard let headNode = currentWalk.findNode(at: cursor, target: &headPosition) else {
        return .init()
      }
      return getSuggestion(
        key: LMUserOverride.formObservationKey(walkedNodes: currentWalk, headIndex: headPosition),
        timestamp: timestamp,
        headReading: headNode.key
      )
    }
  }
}
// MARK: - Private Structures
extension vChewingLM.LMUserOverride {
  /// Coding keys used to persist an `Override`.
  enum OverrideUnit: CodingKey { case count, timestamp }
  /// Coding keys used to persist an `Observation`.
  enum ObservationUnit: CodingKey { case count, overrides }
  /// Coding keys used to persist a `KeyObservationPair`.
  enum KeyObservationPairUnit: CodingKey { case key, observation }

  /// Statistics for one candidate under one observation key.
  struct Override: Hashable, Encodable, Decodable {
    /// Number of times this candidate was chosen.
    var count: Int = 0
    /// Timestamp of the most recent choice.
    var timestamp: Double = 0.0
    /// Flag from the most recent choice.
    /// Not written by `encode(to:)`, so it is `false` after decoding.
    var forceHighScoreOverride = false

    /// Memberwise initializer. It is declared explicitly because the custom
    /// `init(from:)` below would otherwise suppress the synthesized one.
    init(count: Int = 0, timestamp: Double = 0.0, forceHighScoreOverride: Bool = false) {
      self.count = count
      self.timestamp = timestamp
      self.forceHighScoreOverride = forceHighScoreOverride
    }

    /// Decodes exactly the keys written by `encode(to:)`.
    ///
    /// A synthesized decoder would also demand a `forceHighScoreOverride` key. That key
    /// is never encoded, so every persisted override would throw `keyNotFound`.
    init(from decoder: Decoder) throws {
      let container = try decoder.container(keyedBy: OverrideUnit.self)
      count = try container.decode(Int.self, forKey: .count)
      timestamp = try container.decode(Double.self, forKey: .timestamp)
      forceHighScoreOverride = false
    }

    // Equality and hashing deliberately consider only `count` and `timestamp`.
    static func == (lhs: Override, rhs: Override) -> Bool {
      lhs.count == rhs.count && lhs.timestamp == rhs.timestamp
    }

    func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: OverrideUnit.self)
      try container.encode(timestamp, forKey: .timestamp)
      try container.encode(count, forKey: .count)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(count)
      hasher.combine(timestamp)
    }
  }

  /// All candidates observed under one key, plus the total observation count.
  struct Observation: Hashable, Encodable, Decodable {
    /// Total number of observations across all candidates.
    var count: Int = 0
    /// Per-candidate statistics, keyed by candidate value.
    var overrides: [String: Override] = [:]

    static func == (lhs: Observation, rhs: Observation) -> Bool {
      lhs.count == rhs.count && lhs.overrides == rhs.overrides
    }

    func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: ObservationUnit.self)
      try container.encode(count, forKey: .count)
      try container.encode(overrides, forKey: .overrides)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(count)
      hasher.combine(overrides)
    }

    /// Records one more choice of `candidate` at `timestamp`.
    ///
    /// `forceHighScoreOverride` is stored for both new and existing candidates. Previously
    /// a candidate's first observation dropped the flag.
    mutating func update(candidate: String, timestamp: Double, forceHighScoreOverride: Bool = false) {
      count += 1
      var entry = overrides[candidate, default: Override()]
      entry.count += 1
      entry.timestamp = timestamp
      entry.forceHighScoreOverride = forceHighScoreOverride
      overrides[candidate] = entry
    }
  }

  /// An observation together with the key it was recorded under.
  struct KeyObservationPair: Hashable, Encodable, Decodable {
    var key: String
    var observation: Observation

    static func == (lhs: KeyObservationPair, rhs: KeyObservationPair) -> Bool {
      lhs.key == rhs.key && lhs.observation == rhs.observation
    }

    func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: KeyObservationPairUnit.self)
      try container.encode(key, forKey: .key)
      try container.encode(observation, forKey: .observation)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(key)
      hasher.combine(observation)
    }
  }
}
// MARK: - Hash and Dehash the entire UOM data
extension vChewingLM.LMUserOverride {
  /// Removes every observation whose key contains `"(),()"`, rebuilds the LRU list,
  /// then calls `saveCallback`.
  ///
  /// NOTE(review): these are presumably the keys recorded without any preceding
  /// context (empty n-gram slots); confirm against `KeyValuePaired.toNGramKey`.
  public func bleachUnigrams(saveCallback: @escaping () -> Void) {
    mutLRUMap = mutLRUMap.filter { !$0.key.contains("(),()") }
    resetMRUList()
    saveCallback()
  }

  /// Rebuilds `mutLRUList` from `mutLRUMap`, walking the dictionary in reverse iteration order.
  internal func resetMRUList() {
    mutLRUList = mutLRUMap.reversed().map { $0.value }
  }

  /// Empties the in-memory model and overwrites `fileURL` with `"{}"`.
  /// Write failures are logged, not thrown.
  public func clearData(withURL fileURL: URL) {
    mutLRUMap = .init()
    mutLRUList = .init()
    do {
      let nullData = "{}"
      try nullData.write(to: fileURL, atomically: false, encoding: .utf8)
    } catch {
      vCLog("UOM Error: Unable to clear data. Details: \(error)")
      return
    }
  }

  /// Writes `mutLRUMap` as JSON to `fileURL`, or to `fileSaveLocationURL` when `nil`.
  ///
  /// Encoding and write failures are both logged. Encoding used to go through `try?`,
  /// which skipped the catch block and failed silently.
  public func saveData(toURL fileURL: URL? = nil) {
    let targetURL: URL = fileURL ?? fileSaveLocationURL
    do {
      let jsonData = try JSONEncoder().encode(mutLRUMap)
      try jsonData.write(to: targetURL, options: .atomic)
    } catch {
      vCLog("UOM Error: Unable to save data, abort saving. Details: \(error)")
    }
  }

  /// Replaces the in-memory model with the JSON stored at `fileURL`.
  ///
  /// An empty file or a file containing exactly `"{}"` is ignored. On read or decode
  /// failure the error is logged with its details and the current model is kept.
  public func loadData(fromURL fileURL: URL) {
    let data: Data
    do {
      data = try Data(contentsOf: fileURL, options: .mappedIfSafe)
    } catch {
      vCLog("UOM Error: Unable to read file or parse the data, abort loading. Details: \(error)")
      return
    }
    if ["", "{}"].contains(String(data: data, encoding: .utf8)) { return }
    do {
      mutLRUMap = try JSONDecoder().decode([String: KeyObservationPair].self, from: data)
    } catch {
      vCLog("UOM Error: Read file content type invalid, abort loading. Details: \(error)")
      return
    }
    resetMRUList()
  }

  /// Result of a suggestion lookup.
  public struct Suggestion {
    /// `(headReading, unigram)` pairs, appended each time a candidate beat the previous
    /// best score, so the last entry carries the highest score.
    public var candidates = [(String, Megrez.Unigram)]()
    /// Flag of the highest-scoring candidate.
    public var forceHighScoreOverride = false
    public var isEmpty: Bool { candidates.isEmpty }
  }
}
// MARK: - Private Methods
extension vChewingLM.LMUserOverride {
  /// Records that `candidate` was chosen under `key` at `timestamp`.
  ///
  /// The key's entry is moved to the front of `mutLRUList`, and entries beyond
  /// `mutCapacity` are evicted from the back of both the list and the map. Then
  /// `saveCallback` is invoked.
  private func doObservation(
    key: String, candidate: String, timestamp: Double, forceHighScoreOverride: Bool,
    saveCallback: @escaping () -> Void
  ) {
    let isNewKey = mutLRUMap[key] == nil
    var pair = mutLRUMap[key] ?? KeyObservationPair(key: key, observation: .init())
    pair.observation.update(
      candidate: candidate, timestamp: timestamp, forceHighScoreOverride: forceHighScoreOverride
    )
    mutLRUMap[key] = pair

    // Move to front. Removing the stale copy first keeps keys unique in the list.
    // Duplicates would make eviction drop keys that were just refreshed.
    mutLRUList.removeAll { $0.key == key }
    mutLRUList.insert(pair, at: 0)

    // Evict from the back. `popLast()` replaces the old `list[endIndex]` read,
    // which was out of bounds and trapped.
    while mutLRUList.count > mutCapacity, let evicted = mutLRUList.popLast() {
      mutLRUMap.removeValue(forKey: evicted.key)
    }

    if isNewKey {
      vCLog("UOM: Observation finished with new observation: \(key)")
    } else {
      vCLog("UOM: Observation finished with existing observation: \(key)")
    }
    saveCallback()
  }

  /// Builds a suggestion for `key`. Each returned unigram is paired with `headReading`.
  ///
  /// Overrides are visited in dictionary order. A candidate is appended only when its
  /// score strictly exceeds the best so far, so the last candidate is the best one.
  /// Returns an empty suggestion for an empty or unknown key.
  private func getSuggestion(key: String, timestamp: Double, headReading: String) -> Suggestion {
    guard !key.isEmpty, let kvPair = mutLRUMap[key] else { return .init() }
    let observation: Observation = kvPair.observation
    var candidates: [(String, Megrez.Unigram)] = .init()
    var forceHighScoreOverride = false
    var currentHighScore: Double = 0
    for (candidateValue, entry) in observation.overrides {
      let overrideScore = getScore(
        eventCount: entry.count, totalCount: observation.count,
        eventTimestamp: entry.timestamp, timestamp: timestamp, lambda: mutDecayExponent
      )
      // A strict `>` also rejects NaN and negative scores. A NaN high score would make
      // the old `0...currentHighScore` range construction trap.
      guard overrideScore > currentHighScore else { continue }
      candidates.append((headReading, .init(value: candidateValue, score: overrideScore)))
      forceHighScoreOverride = entry.forceHighScoreOverride
      currentHighScore = overrideScore
    }
    return .init(candidates: candidates, forceHighScoreOverride: forceHighScoreOverride)
  }

  /// Score = (eventCount / totalCount) * exp((timestamp - eventTimestamp) * lambda).
  ///
  /// Returns 0 when `totalCount` is not positive (avoids division by zero) or when the
  /// decay factor falls below `kDecayThreshold`.
  private func getScore(
    eventCount: Int,
    totalCount: Int,
    eventTimestamp: Double,
    timestamp: Double,
    lambda: Double
  ) -> Double {
    guard totalCount > 0 else { return 0.0 }
    let decay = exp((timestamp - eventTimestamp) * lambda)
    if decay < kDecayThreshold { return 0.0 }
    let prob = Double(eventCount) / Double(totalCount)
    return prob * decay
  }

  /// True when the first non-empty key of `node` starts with `"_"`.
  private static func isPunctuation(_ node: Megrez.Compositor.Node) -> Bool {
    guard let firstKey = node.keyArray.first(where: { !$0.isEmpty }) else { return false }
    return firstKey.first == "_"
  }

  /// Forms the observation key for the node ending around `cursorIndex`.
  ///
  /// The key has the form `(anterior,previous,current)`, or just the current reading
  /// when `readingOnly` is true. Returns `""` in these cases:
  /// - the head contains `"_"`;
  /// - the head's reading count differs from its character count;
  /// - the head is a single non-whitelisted character with no usable previous context.
  private static func formObservationKey(
    walkedNodes: [Megrez.Compositor.Node], headIndex cursorIndex: Int, readingOnly: Bool = false
  ) -> String {
    // Single characters allowed to form a key even without preceding context.
    let whiteList = "你他妳她祢衪它牠再在"

    // Collect nodes from the start until their accumulated span reaches `cursorIndex`.
    var arrNodes: [Megrez.Compositor.Node] = []
    var intLength = 0
    for theNodeAnchor in walkedNodes {
      arrNodes.append(theNodeAnchor)
      intLength += theNodeAnchor.spanLength
      if intLength >= cursorIndex { break }
    }
    if arrNodes.isEmpty { return "" }
    // After reversing: [0] = head, [1] = previous, [2] = anterior.
    arrNodes = Array(arrNodes.reversed())

    let kvCurrent = arrNodes[0].currentPair
    guard !kvCurrent.key.contains("_") else { return "" }
    // Drop heads whose "-"-separated reading count differs from their character count.
    if kvCurrent.key.split(separator: "-").count != kvCurrent.value.count { return "" }

    /// A context pair is usable when it has no "_" and its reading count matches its
    /// character count.
    func isUsableContext(_ pair: Megrez.Compositor.KeyValuePaired) -> Bool {
      !pair.key.contains("_") && pair.key.split(separator: "-").count == pair.value.count
    }

    let strCurrent = kvCurrent.key
    var kvPrevious = Megrez.Compositor.KeyValuePaired()
    var kvAnterior = Megrez.Compositor.KeyValuePaired()
    var readingStack = ""

    // Validate each fetched pair *before* adopting it. The old code tested the
    // still-default placeholders, so these checks could never reject anything.
    if arrNodes.count >= 2 {
      let fetched = arrNodes[1].currentPair
      if isUsableContext(fetched) {
        kvPrevious = fetched
        readingStack = kvPrevious.key + readingStack
      }
    }
    if arrNodes.count >= 3 {
      let fetched = arrNodes[2].currentPair
      if isUsableContext(fetched) {
        kvAnterior = fetched
        readingStack = kvAnterior.key + readingStack
      }
    }

    // Do not memorize a lone non-whitelisted single character without previous
    // context; doing so would disrupt ordinary typing too much.
    if readingStack.contains("_")
      || (!kvPrevious.isValid && kvCurrent.value.count == 1 && !whiteList.contains(kvCurrent.value))
    {
      return ""
    }
    return readingOnly
      ? strCurrent
      : "(\(kvAnterior.toNGramKey),\(kvPrevious.toNGramKey),\(strCurrent))"
  }
}