UserPhrase // Improve score boosting / nerfing for single kanji.

This commit is contained in:
ShikiSuen 2024-03-02 20:15:25 +08:00
parent b628ddd082
commit 35d4426730
3 changed files with 75 additions and 3 deletions

View File

@ -6,8 +6,10 @@
// marks, or product names of Contributor, except as required to fulfill notice
// requirements defined in MIT License.
import CandidateWindow
import Foundation
import LangModelAssembly
import Megrez
import Shared
// MARK: - 使
@ -32,6 +34,10 @@ public extension LMMgr {
!keyArray.isEmpty && keyArray.filter(\.isEmpty).isEmpty && !value.isEmpty
}
/// Whether this phrase is a single character paired with a single reading key.
///
/// True only when `value` has exactly one character, `keyArray` has exactly one
/// entry, and that entry does not begin with an underscore.
var isSingleCharReadingPair: Bool {
  guard value.count == 1, keyArray.count == 1 else { return false }
  return keyArray.first?.first != "_"
}
/// Textual form of the phrase: the entries of `descriptionCells` joined by single spaces.
public var description: String {
  let cells = descriptionCells
  return cells.joined(separator: " ")
}
@ -190,3 +196,45 @@ public extension LMMgr {
}
}
}
// MARK: - Weight Suggestions.
public extension LMMgr.UserPhrase {
  /// Replaces `weight` with whatever `suggestNextFreq(for:)` proposes for `action`.
  ///
  /// - Parameter action: The candidate context-menu action being applied.
  mutating func updateWeight(basedOn action: CandidateContextMenuAction) {
    weight = suggestNextFreq(for: action)
  }

  /// Proposes the weight to store for this phrase after a boost / nerf / filter action.
  ///
  /// For single-character reading pairs (and when `extreme` is `false`) the proposal is
  /// placed just beyond the neighboring factory unigram score, so the candidate moves by
  /// one rank instead of jumping to an extreme. In every other situation a coarse value
  /// is returned: `-114.514` for nerfing and `nil` for boosting or filtering.
  ///
  /// - Parameters:
  ///   - action: The candidate context-menu action being applied.
  ///   - extreme: When `true`, skips the neighbor-based calculation and returns the
  ///     coarse value directly.
  /// - Returns: The suggested weight, or `nil` when no explicit weight is suggested.
  func suggestNextFreq(for action: CandidateContextMenuAction, extreme: Bool = false) -> Double? {
    // Lowest weight this function ever suggests on its own, and the offset used to
    // land just past a neighboring factory score.
    let nerfFloor: Double = -114.514
    let nudge: Double = 0.0001

    // Coarse answer used whenever the fine-grained calculation is not applicable.
    let coarseResult: Double?
    switch action {
    case .toNerf: coarseResult = nerfFloor
    case .toBoost, .toFilter: coarseResult = nil
    }

    if extreme || !isSingleCharReadingPair { return coarseResult }

    let readingKey = keyArray.joined(separator: "-")
    let factoryUnigrams = inputMode.langModel.factoryCoreUnigramsFor(key: readingKey)
    // Prefer the weight already on this phrase; otherwise fall back to the factory
    // score of the unigram carrying the same value.
    let factoryMatch = factoryUnigrams.first(where: { $0.value == value })
    guard let baseline = weight ?? factoryMatch?.score else { return coarseResult }
    let factoryScores = factoryUnigrams.map(\.score)

    // NOTE(review): the neighbor-based results below are not clamped to the
    // [-114.514, 0] range that the fallback branches enforce — confirm this is intended.
    switch action {
    case .toFilter:
      return nil
    case .toBoost:
      guard let above = baseline.findNeighborValue(from: factoryScores, greater: true) else {
        // Nothing scores higher: step up by one, capped at zero.
        return Swift.min(0, baseline + 1)
      }
      return above + nudge
    case .toNerf:
      guard let below = baseline.findNeighborValue(from: factoryScores, greater: false) else {
        // Nothing scores lower: step down by one, floored at the nerf floor.
        return Swift.max(nerfFloor, baseline - 1)
      }
      return below - nudge
    }
  }
}

View File

@ -46,8 +46,11 @@ extension SessionCtl: InputHandlerDelegate {
var userPhrase = LMMgr.UserPhrase(
keyArray: kvPair.keyArray, value: kvPair.value, inputMode: inputMode
)
if Self.areWeNerfing { userPhrase.weight = -114.514 }
LMMgr.writeUserPhrasesAtOnce(userPhrase, areWeFiltering: addToFilter) {
var action = CandidateContextMenuAction.toBoost
if Self.areWeNerfing { action = .toNerf }
if addToFilter { action = .toFilter }
userPhrase.updateWeight(basedOn: action)
LMMgr.writeUserPhrasesAtOnce(userPhrase, areWeFiltering: action == .toFilter) {
succeeded = false
}
if !succeeded { return false }
@ -275,7 +278,7 @@ extension SessionCtl: CtlCandidateDelegate {
var userPhrase = LMMgr.UserPhrase(
keyArray: rawPair.keyArray, value: rawPair.value, inputMode: inputMode
)
if action == .toNerf { userPhrase.weight = -114.514 }
userPhrase.updateWeight(basedOn: action)
LMMgr.writeUserPhrasesAtOnce(userPhrase, areWeFiltering: action == .toFilter) {
succeeded = false
}

View File

@ -398,3 +398,24 @@ public enum ArrayBuilder<OutputModel> {
Array(components.joined())
}
}
// MARK: - Extending Comparable so that a value can find its neighboring values in any collection.
public extension Comparable {
  /// Finds the element of `givenSeq` that sits immediately next to `self` on one side.
  ///
  /// - Parameters:
  ///   - givenSeq: Candidate values. They need not be sorted and may contain duplicates.
  ///   - isGreater: Pass `true` to get the smallest element strictly greater than `self`,
  ///     or `false` to get the largest element strictly less than `self`.
  /// - Returns: The neighboring value, or `nil` when no element lies on the requested
  ///   side (which includes an empty `givenSeq`). Elements equal to `self` are never
  ///   returned.
  func findNeighborValue(from givenSeq: any Collection<Self>, greater isGreater: Bool) -> Self? {
    // The input is unsorted, so sorting it just to binary-search once costs more than
    // a single linear scan. Keep only the elements on the requested side, then take
    // the one closest to `self`.
    let candidates = givenSeq.filter { isGreater ? $0 > self : $0 < self }
    return isGreater ? candidates.min() : candidates.max()
  }
}