// File: vChewing-macOS/Packages/vChewing_LangModelAssembly/Sources/LangModelAssembly/SubLMs/lmUserOverride.swift
// (Captured from a web code viewer: 409 lines, 15 KiB, Swift.)
// Viewer notice: this file contains Unicode characters that might be confused with
// other characters; if that is intentional, the warning can be safely ignored.

// (c) 2021 and onwards The vChewing Project (MIT-NTL License).
// Refactored from the Cpp version of this class by Lukhnos Liu (MIT License).
// ====================
// This code is released under the MIT license (SPDX-License-Identifier: MIT)
// ... with NTL restriction stating that:
// No trademark license is granted to use the trade names, trademarks, service
// marks, or product names of Contributor, except as required to fulfill notice
// requirements defined in MIT License.
import Foundation
import Megrez
// MARK: - Public Types.
public extension LMAssembly {
  /// Result of a user-override lookup: zero or more suggested candidates plus a flag.
  struct OverrideSuggestion {
    /// Suggested candidates, each paired with the reading (joined key) it applies to.
    public var candidates: [(String, Megrez.Unigram)] = []
    /// Copied from the stored override's `forceHighScoreOverride` flag;
    /// how it is applied is up to the consumer of this suggestion.
    public var forceHighScoreOverride: Bool = false
    /// `true` when no candidate was suggested.
    public var isEmpty: Bool {
      return candidates.isEmpty
    }
  }
}
// MARK: - LMUserOverride Class Definition.
extension LMAssembly {
  /// User override model: remembers which candidate the user chose in which context,
  /// stores those observations in a map plus a recency-ordered list, and scores them
  /// with exponential time decay.
  class LMUserOverride {
    /// Maximum number of observation keys retained; always at least 1.
    var mutCapacity: Int
    /// Decay factor per time unit, derived from the half-life passed to `init`.
    var mutDecayExponent: Double
    /// Observations, intended to be ordered most-recently-used first
    /// (new observations are inserted at index 0).
    var mutLRUList: [KeyObservationPair] = []
    /// Observations keyed by their observation key.
    var mutLRUMap: [String: KeyObservationPair] = [:]
    /// Decay factors below this value (2^-20) make an override score count as zero.
    let kDecayThreshold: Double = 0x1p-20
    /// Default file used by the save/load/clear helpers when they are given no URL.
    var fileSaveLocationURL: URL?
    /// Default override half-life: six hours (presumably expressed in seconds).
    public static let kObservedOverrideHalfLife: Double = 6 * 60 * 60.0

    /// Creates a user override model.
    /// - Parameters:
    ///   - capacity: Maximum number of observation keys kept; values below 1 become 1.
    ///   - decayConstant: Half-life of an observation, in the same unit as timestamps.
    ///   - dataURL: Default file location for persistence.
    public init(capacity: Int = 500, decayConstant: Double = LMUserOverride.kObservedOverrideHalfLife, dataURL: URL? = nil) {
      fileSaveLocationURL = dataURL
      // Ensures that the capacity is always > 0.
      mutCapacity = capacity > 0 ? capacity : 1
      // exp(t * log(0.5) / h) halves every h time units.
      mutDecayExponent = log(0.5) / decayConstant
    }
  }
}
// MARK: - Private Structures
extension LMAssembly.LMUserOverride {
  /// Coding keys used when encoding `Override`.
  enum OverrideUnit: CodingKey { case count, timestamp, forceHighScoreOverride }
  /// Coding keys used when encoding `Observation`.
  enum ObservationUnit: CodingKey { case count, overrides }
  /// Coding keys used when encoding `KeyObservationPair`.
  enum KeyObservationPairUnit: CodingKey { case key, observation }

  /// How often, and how recently, one candidate was chosen under a given observation key.
  ///
  /// Decoding is compiler-synthesized. It uses key names identical to those written by
  /// `encode(to:)`.
  struct Override: Hashable, Encodable, Decodable {
    var count: Int = 0
    var timestamp: Double = 0.0
    var forceHighScoreOverride = false

    /// Compares every stored field so that equality agrees with `hash(into:)`.
    /// `Hashable` requires equal values to produce equal hashes.
    static func == (lhs: Override, rhs: Override) -> Bool {
      lhs.count == rhs.count
        && lhs.timestamp == rhs.timestamp
        && lhs.forceHighScoreOverride == rhs.forceHighScoreOverride
    }

    func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: OverrideUnit.self)
      try container.encode(count, forKey: .count)
      try container.encode(timestamp, forKey: .timestamp)
      try container.encode(forceHighScoreOverride, forKey: .forceHighScoreOverride)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(count)
      hasher.combine(timestamp)
      hasher.combine(forceHighScoreOverride)
    }
  }

  /// All overrides recorded under one observation key, plus the total observation count.
  struct Observation: Hashable, Encodable, Decodable {
    var count: Int = 0
    var overrides: [String: Override] = [:]

    static func == (lhs: Observation, rhs: Observation) -> Bool {
      lhs.count == rhs.count && lhs.overrides == rhs.overrides
    }

    func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: ObservationUnit.self)
      try container.encode(count, forKey: .count)
      try container.encode(overrides, forKey: .overrides)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(count)
      hasher.combine(overrides)
    }

    /// Records one more choice of `candidate`.
    ///
    /// Increments both the total count and the candidate's own count. It also stamps the
    /// candidate with `timestamp` and `forceHighScoreOverride`. The flag is stored on the
    /// candidate's first observation as well as on later ones.
    mutating func update(candidate: String, timestamp: Double, forceHighScoreOverride: Bool = false) {
      count += 1
      if var existing = overrides[candidate] {
        existing.timestamp = timestamp
        existing.count += 1
        existing.forceHighScoreOverride = forceHighScoreOverride
        overrides[candidate] = existing
      } else {
        overrides[candidate] = .init(
          count: 1, timestamp: timestamp, forceHighScoreOverride: forceHighScoreOverride
        )
      }
    }
  }

  /// An observation together with the key it is stored under.
  struct KeyObservationPair: Hashable, Encodable, Decodable {
    var key: String
    var observation: Observation

    static func == (lhs: KeyObservationPair, rhs: KeyObservationPair) -> Bool {
      lhs.key == rhs.key && lhs.observation == rhs.observation
    }

    func encode(to encoder: Encoder) throws {
      var container = encoder.container(keyedBy: KeyObservationPairUnit.self)
      try container.encode(key, forKey: .key)
      try container.encode(observation, forKey: .observation)
    }

    func hash(into hasher: inout Hasher) {
      hasher.combine(key)
      hasher.combine(observation)
    }
  }
}
// MARK: - Internal Methods in LMAssembly.
extension LMAssembly.LMUserOverride {
  /// Records a user override by comparing the walk before and after the user's edit.
  ///
  /// The node under `cursor` in `walkedAfter` is the candidate the user chose. The
  /// position just before it is looked up in `walkedBefore` to decide where the
  /// observation key is anchored.
  ///
  /// - Parameters:
  ///   - walkedBefore: Walked nodes before the edit.
  ///   - walkedAfter: Walked nodes after the edit.
  ///   - cursor: Cursor position used to find the edited node in `walkedAfter`.
  ///   - timestamp: Observation time, in the same unit as the decay half-life.
  ///   - saveCallback: Run instead of `saveData()` when non-nil.
  func performObservation(
    walkedBefore: [Megrez.Node], walkedAfter: [Megrez.Node],
    cursor: Int, timestamp: Double, saveCallback: (() -> Void)? = nil
  ) {
    // Both walks must be non-empty and cover the same number of keys.
    guard !walkedAfter.isEmpty, !walkedBefore.isEmpty else { return }
    guard walkedBefore.totalKeyCount == walkedAfter.totalKeyCount else { return }
    // Find the node the user just changed.
    var actualCursor = 0
    guard let currentNode = walkedAfter.findNode(at: cursor, target: &actualCursor) else { return }
    // Nodes spanning more than three keys are not recorded.
    guard currentNode.spanLength <= 3 else { return }
    // NOTE(review): `target` is filled in by Megrez's `findNode(at:target:)`; presumably a
    // node boundary position. At position 0 there is nothing before it to look up.
    guard actualCursor > 0 else { return }
    let currentNodeIndex = actualCursor
    actualCursor -= 1
    var prevNodeIndex = 0
    guard let prevNode = walkedBefore.findNode(at: actualCursor, target: &prevNodeIndex) else { return }
    // A replacement node longer than the one it replaced is flagged for a forced high score.
    let forceHighScoreOverride: Bool = currentNode.spanLength > prevNode.spanLength
    // When a single-key node replaces a multi-key one, anchor the key at the current node;
    // otherwise anchor it at the previous node's position.
    let breakingUp = currentNode.spanLength == 1 && prevNode.spanLength > 1
    let targetNodeIndex = breakingUp ? currentNodeIndex : prevNodeIndex
    let key: String = LMAssembly.LMUserOverride.formObservationKey(
      walkedNodes: walkedAfter, headIndex: targetNodeIndex
    )
    guard !key.isEmpty else { return }
    doObservation(
      key: key, candidate: currentNode.currentUnigram.value, timestamp: timestamp,
      forceHighScoreOverride: forceHighScoreOverride, saveCallback: saveCallback
    )
  }

  /// Looks up the stored override suggestion for the node under `cursor`.
  ///
  /// - Returns: An empty suggestion when no node is found at `cursor`; otherwise whatever
  ///   `getSuggestion` yields for the node's observation key.
  func fetchSuggestion(
    currentWalk: [Megrez.Node], cursor: Int, timestamp: Double
  ) -> LMAssembly.OverrideSuggestion {
    var headIndex = 0
    guard let nodeIter = currentWalk.findNode(at: cursor, target: &headIndex) else { return .init() }
    let key = LMAssembly.LMUserOverride.formObservationKey(walkedNodes: currentWalk, headIndex: headIndex)
    return getSuggestion(key: key, timestamp: timestamp, headReading: nodeIter.joinedKey())
  }

  /// Removes every observation that holds an override for any of `targets`, then persists.
  func bleachSpecifiedSuggestions(targets: [String], saveCallback: (() -> Void)? = nil) {
    if targets.isEmpty { return }
    mutLRUMap = mutLRUMap.filter { _, pair in
      !targets.contains { target in pair.observation.overrides.keys.contains(target) }
    }
    resetMRUList()
    saveCallback?() ?? saveData()
  }

  /// Removes every observation whose key contains `(),()`, then persists.
  /// NOTE(review): such keys presumably have empty context slots, i.e. unigram observations.
  func bleachUnigrams(saveCallback: (() -> Void)? = nil) {
    mutLRUMap = mutLRUMap.filter { key, _ in !key.contains("(),()") }
    resetMRUList()
    saveCallback?() ?? saveData()
  }

  /// Rebuilds `mutLRUList` from `mutLRUMap`, most recently observed first.
  ///
  /// `Dictionary` iteration order is unspecified, so the list is ordered by each
  /// observation's newest override timestamp instead. Tail eviction then drops
  /// stale entries rather than arbitrary ones.
  func resetMRUList() {
    func latestTimestamp(_ pair: KeyObservationPair) -> Double {
      pair.observation.overrides.values.map { $0.timestamp }.max() ?? 0
    }
    mutLRUList = mutLRUMap.values.sorted { latestTimestamp($0) > latestTimestamp($1) }
  }

  /// Empties the in-memory model and overwrites the data file with `{}`.
  ///
  /// - Parameter fileURL: Target file; defaults to `fileSaveLocationURL`.
  ///   Failures, including a missing URL, are logged rather than thrown.
  func clearData(withURL fileURL: URL? = nil) {
    mutLRUMap = .init()
    mutLRUList = .init()
    do {
      let nullData = "{}"
      guard let fileURL = fileURL ?? fileSaveLocationURL else {
        throw UOMError(rawValue: "given fileURL is invalid or nil.")
      }
      try nullData.write(to: fileURL, atomically: false, encoding: .utf8)
    } catch {
      vCLMLog("UOM Error: Unable to clear the data in the UOM file. Details: \(error)")
    }
  }

  /// Writes `mutLRUMap` as JSON to `fileURL`, falling back to `fileSaveLocationURL`.
  ///
  /// Encoding and write failures are both logged.
  func saveData(toURL fileURL: URL? = nil) {
    guard let fileURL: URL = fileURL ?? fileSaveLocationURL else {
      vCLMLog("UOM saveData() failed. At least the file Save URL is not set for the current UOM.")
      return
    }
    let encoder = JSONEncoder()
    do {
      let jsonData = try encoder.encode(mutLRUMap)
      try jsonData.write(to: fileURL, options: .atomic)
    } catch {
      vCLMLog("UOM Error: Unable to save data, abort saving. Details: \(error)")
    }
  }

  /// Replaces the in-memory model with the JSON content of `fileURL`
  /// (falling back to `fileSaveLocationURL`), then rebuilds the recency list.
  ///
  /// An empty file or `{}` leaves the model untouched. Read and decode failures
  /// are logged and also leave the model untouched.
  func loadData(fromURL fileURL: URL? = nil) {
    guard let fileURL: URL = fileURL ?? fileSaveLocationURL else {
      vCLMLog("UOM loadData() failed. At least the file Load URL is not set for the current UOM.")
      return
    }
    let decoder = JSONDecoder()
    do {
      let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)
      if ["", "{}"].contains(String(data: data, encoding: .utf8)) { return }
      let jsonResult: [String: KeyObservationPair]
      do {
        jsonResult = try decoder.decode([String: KeyObservationPair].self, from: data)
      } catch {
        vCLMLog("UOM Error: Read file content type invalid, abort loading. Details: \(error)")
        return
      }
      mutLRUMap = jsonResult
      resetMRUList()
    } catch {
      vCLMLog("UOM Error: Unable to read file or parse the data, abort loading. Details: \(error)")
    }
  }
}
// MARK: - Other Non-Public Internal Methods
extension LMAssembly.LMUserOverride {
  /// Records that `candidate` was chosen under `key`, then persists.
  ///
  /// The entry is moved to the front of `mutLRUList`, and any previous copy of the same
  /// key is removed so each key appears in the list at most once. When `key` is new and
  /// the list then exceeds `mutCapacity`, entries are evicted from the tail (and from
  /// `mutLRUMap`) until it fits.
  func doObservation(
    key: String, candidate: String, timestamp: Double, forceHighScoreOverride: Bool,
    saveCallback: (() -> Void)?
  ) {
    let isNewKey = mutLRUMap[key] == nil
    var pair = mutLRUMap[key] ?? KeyObservationPair(key: key, observation: .init())
    pair.observation.update(
      candidate: candidate, timestamp: timestamp, forceHighScoreOverride: forceHighScoreOverride
    )
    mutLRUMap[key] = pair
    // Move to front; drop the stale copy so tail eviction never hits a live key by mistake.
    mutLRUList.removeAll { $0.key == key }
    mutLRUList.insert(pair, at: 0)
    if isNewKey {
      while mutLRUList.count > mutCapacity, let evicted = mutLRUList.popLast() {
        mutLRUMap.removeValue(forKey: evicted.key)
      }
      vCLMLog("UOM: Observation finished with new observation: \(key)")
    } else {
      vCLMLog("UOM: Observation finished with existing observation: \(key)")
    }
    saveCallback?() ?? saveData()
  }

  /// Builds the override suggestion stored under `key`.
  ///
  /// Overrides are visited in dictionary order. Each one whose decayed score is strictly
  /// higher than the best seen so far is appended, paired with `headReading`.
  /// `forceHighScoreOverride` comes from the last appended override.
  /// NOTE(review): dictionary order is unspecified, so which lower-scoring candidates
  /// precede the best one may vary between runs.
  func getSuggestion(key: String, timestamp: Double, headReading: String) -> LMAssembly.OverrideSuggestion {
    guard !key.isEmpty, let kvPair = mutLRUMap[key] else { return .init() }
    let observation: Observation = kvPair.observation
    // Keys containing "(),()," (presumably unigram keys with empty context) decay 24x faster.
    // If such a key also contains no "-" reading separator (a single reading), it decays
    // a further 12x faster. This depends only on `key`, so it is computed once.
    let isUnigramKey = key.contains("(),(),")
    var decayExp = mutDecayExponent * (isUnigramKey ? 24 : 1)
    if isUnigramKey, !key.replacingOccurrences(of: "(),(),", with: "").contains("-") { decayExp *= 12 }
    var candidates: [(String, Megrez.Unigram)] = .init()
    var forceHighScoreOverride = false
    var currentHighScore: Double = 0
    for (i, theObservation) in observation.overrides {
      let overrideScore = getScore(
        eventCount: theObservation.count, totalCount: observation.count,
        eventTimestamp: theObservation.timestamp, timestamp: timestamp, lambda: decayExp
      )
      // Skip scores that do not beat the current best (including zero scores).
      if (0 ... currentHighScore).contains(overrideScore) { continue }
      candidates.append((headReading, .init(value: i, score: overrideScore)))
      forceHighScoreOverride = theObservation.forceHighScoreOverride
      currentHighScore = overrideScore
    }
    return .init(candidates: candidates, forceHighScoreOverride: forceHighScoreOverride)
  }

  /// Scores an override as `eventCount / totalCount` multiplied by the decay factor
  /// `exp((timestamp - eventTimestamp) * lambda)`.
  ///
  /// - Returns: 0 when the decay factor falls below `kDecayThreshold`.
  func getScore(
    eventCount: Int,
    totalCount: Int,
    eventTimestamp: Double,
    timestamp: Double,
    lambda: Double
  ) -> Double {
    let decay = exp((timestamp - eventTimestamp) * lambda)
    if decay < kDecayThreshold { return 0.0 }
    let prob = Double(eventCount) / Double(totalCount)
    return prob * decay
  }

  /// Returns whether the first non-empty key of `node` starts with `_`.
  static func isPunctuation(_ node: Megrez.Node) -> Bool {
    for key in node.keyArray {
      guard let firstChar = key.first else { continue }
      return String(firstChar) == "_"
    }
    return false
  }

  /// Forms the observation key for the node reached at `cursorIndex`.
  ///
  /// The key has the form `(anterior,previous,current)`. Context pairs that contain `_`,
  /// or whose reading count differs from their value length, are replaced by empty pairs.
  ///
  /// - Parameters:
  ///   - walkedNodes: The walked nodes.
  ///   - cursorIndex: Accumulated key span at which the "current" node is reached.
  ///   - readingOnly: When `true`, return only the current node's joined key.
  /// - Returns: The key, or `""` when the current node is unsuitable for observation.
  static func formObservationKey(
    walkedNodes: [Megrez.Node], headIndex cursorIndex: Int, readingOnly: Bool = false
  ) -> String {
    // Collect nodes from the start until the accumulated span reaches `cursorIndex`;
    // the last node collected is the current one.
    var arrNodes: [Megrez.Node] = []
    var intLength = 0
    for theNodeAnchor in walkedNodes {
      arrNodes.append(theNodeAnchor)
      intLength += theNodeAnchor.spanLength
      if intLength >= cursorIndex {
        break
      }
    }
    if arrNodes.isEmpty { return "" }
    // Reverse so index 0 is the current node, 1 the previous one, 2 the one before that.
    arrNodes = Array(arrNodes.reversed())
    let kvCurrent = arrNodes[0].currentPair
    // Keys containing "_" are never observed.
    guard !kvCurrent.joinedKey().contains("_") else {
      return ""
    }
    // Skip pairs whose number of readings differs from the number of value characters.
    if kvCurrent.keyArray.count != kvCurrent.value.count { return "" }
    let strCurrent = kvCurrent.joinedKey()
    var kvPrevious = Megrez.KeyValuePaired(keyArray: [""], value: "")
    var kvAnterior = Megrez.KeyValuePaired(keyArray: [""], value: "")
    var readingStack = ""
    var trigramKey: String { "(\(kvAnterior.toNGramKey),\(kvPrevious.toNGramKey),\(strCurrent))" }
    var result: String {
      // A whitelist-based rejection of single-character current values lacking a valid
      // previous pair (`!kvPrevious.isValid && kvCurrent.value.count == 1 && ...`) is left disabled.
      if readingStack.contains("_") {
        return ""
      } else {
        return (readingOnly ? strCurrent : trigramKey)
      }
    }
    // A context pair qualifies when it has no "_" and one "-"-separated reading per value character.
    func checkKeyValueValidityInThisContext(_ target: Megrez.KeyValuePaired) -> Bool {
      !target.joinedKey().contains("_") && target.joinedKey().split(separator: "-").count == target.value.count
    }
    if arrNodes.count >= 2 {
      let maybeKvPrevious = arrNodes[1].currentPair
      if checkKeyValueValidityInThisContext(maybeKvPrevious) {
        kvPrevious = maybeKvPrevious
        readingStack = kvPrevious.joinedKey() + readingStack
      }
    }
    if arrNodes.count >= 3 {
      let maybeKvAnterior = arrNodes[2].currentPair
      if checkKeyValueValidityInThisContext(maybeKvAnterior) {
        kvAnterior = maybeKvAnterior
        readingStack = kvAnterior.joinedKey() + readingStack
      }
    }
    return result
  }
}
/// Error thrown by the user override model's file helpers; carries a human-readable message.
struct UOMError: LocalizedError {
  var rawValue: String
  /// Localized form of `rawValue`; `NSLocalizedString` returns `rawValue` itself when no
  /// translation is found. The previous code looked up the literal key "rawValue".
  var errorDescription: String? {
    NSLocalizedString(rawValue, comment: "")
  }
}