Megrez // Allow resynchronizing unigram data in update().

This commit is contained in:
ShikiSuen 2022-11-30 21:11:52 +08:00
parent d870d5ad2a
commit 9028c6a5dd
5 changed files with 67 additions and 25 deletions

View File

@ -172,7 +172,7 @@ extension Megrez {
} }
} }
// MARK: - Internal Methods // MARK: - Internal Methods (Maybe Public)
extension Megrez.Compositor { extension Megrez.Compositor {
// MARK: Internal methods for maintaining the grid. // MARK: Internal methods for maintaining the grid.
@ -242,45 +242,51 @@ extension Megrez.Compositor {
return true return true
} }
/// Joins the keys covered by `range` into one string, using the compositor's separator.
///
/// - Parameter range: Half-open range of key indices.
/// - Returns: The joined key, or an empty string when `range` is not inside `keys`.
func getJointKey(range: Range<Int>) -> String {
  // Bounds are compared by hand. The original note on this line mentions `contains`
  // and macOS 13 Ventura — presumably an OS-availability constraint (TODO confirm).
  let fitsInsideKeys = range.lowerBound >= 0 && range.upperBound <= keys.count
  if !fitsInsideKeys { return "" }
  let coveredKeys = keys[range]
  return coveredKeys.joined(separator: separator)
}
/// Collects the keys covered by `range` as an array of strings.
///
/// - Parameter range: Half-open range of key indices.
/// - Returns: The covered keys, or an empty array when `range` is not inside `keys`.
func getJointKeyArray(range: Range<Int>) -> [String] {
  // Bounds are compared by hand. The original note on this line mentions `contains`
  // and macOS 13 Ventura — presumably an OS-availability constraint (TODO confirm).
  if range.lowerBound < 0 || range.upperBound > keys.count { return [] }
  let coveredKeys = keys[range]
  return coveredKeys.map { key in String(key) }
}
/// Looks up the node of the given span length at `location` and returns it only when
/// its key array matches `keyArray`.
///
/// - Parameters:
///   - location: Span index. Out-of-range values are clamped to the nearest valid index.
///   - length: Span length of the node being looked for.
///   - keyArray: The keys the node is expected to carry.
/// - Returns: The matching node, or `nil` when the grid is empty, no node of that length
///   exists at the location, or its keys differ.
func getNode(at location: Int, length: Int, keyArray: [String]) -> Node? {
  // An empty grid has no valid index to clamp to.
  guard !spans.isEmpty else { return nil }
  // Clamp to the last valid index. Clamping to `spans.count` itself (as before)
  // would let the subscript below trap.
  let location = max(min(location, spans.count - 1), 0)
  guard let node = spans[location].nodeOf(length: length) else { return nil }
  return keyArray == node.keyArray ? node : nil
}
/// Fills the grid around the cursor with nodes built from the language model and,
/// on request, resynchronizes the nodes that are already there.
///
/// Every key span within `maxSpanLength` of the cursor is visited. A span without a
/// node gets one when the language model has unigrams for its joint key.
///
/// - Parameter updateExisting: When `true`, existing nodes are refreshed with the
///   language model's current unigrams; a multi-key node whose key no longer has any
///   unigrams is removed. When `false` (the default), existing nodes are left untouched.
/// - Returns: The number of nodes inserted, refreshed, or removed. 0 means nothing changed.
@discardableResult public mutating func update(updateExisting: Bool = false) -> Int {
  let maxSpanLength = Megrez.Compositor.maxSpanLength
  let range = max(0, cursor - maxSpanLength)..<min(cursor + maxSpanLength, keys.count)
  var nodesChanged = 0
  for position in range {
    for theLength in 1...min(maxSpanLength, range.upperBound - position) {
      let jointKeyArray = getJointKeyArray(range: position..<(position + theLength))
      let jointKey = jointKeyArray.joined(separator: separator)
      if let theNode = getNode(at: position, length: theLength, keyArray: jointKeyArray) {
        if !updateExisting { continue }
        let unigrams = langModel.unigramsFor(key: jointKey)
        if unigrams.isEmpty {
          // Single-key nodes are kept even without unigrams — presumably so every
          // position stays covered by at least one node (TODO confirm).
          if theNode.keyArray.count == 1 { continue }
          // Blank the slot in place. `Span.clear()` builds `nodes` with a fixed number
          // of slots, so removing elements would shift the remaining nodes.
          // NOTE(review): `Span.maxLength` is not writable from here and may be stale
          // after a removal — confirm against `Span`.
          for slot in spans[position].nodes.indices where spans[position].nodes[slot] == theNode {
            spans[position].nodes[slot] = nil
          }
        } else {
          theNode.resetUnigrams(using: unigrams)
        }
        nodesChanged += 1
        continue
      }
      let unigrams = langModel.unigramsFor(key: jointKey)
      guard !unigrams.isEmpty else { continue }
      insertNode(
        .init(keyArray: jointKeyArray, spanLength: theLength, unigrams: unigrams, keySeparator: separator),
        at: position
      )
      nodesChanged += 1
    }
  }
  return nodesChanged
}
mutating func updateCursorJumpingTables(_ walkedNodes: [Node]) { mutating func updateCursorJumpingTables(_ walkedNodes: [Node]) {

View File

@ -5,15 +5,15 @@
extension Megrez.Compositor { extension Megrez.Compositor {
/// ///
public struct Span { public class Span {
private var nodes: [Node?] = [] public var nodes: [Node?] = []
public private(set) var maxLength = 0 public private(set) var maxLength = 0
private var maxSpanLength: Int { Megrez.Compositor.maxSpanLength } private var maxSpanLength: Int { Megrez.Compositor.maxSpanLength }
public init() { public init() {
clear() clear()
} }
public mutating func clear() { public func clear() {
nodes.removeAll() nodes.removeAll()
for _ in 0..<maxSpanLength { for _ in 0..<maxSpanLength {
nodes.append(nil) nodes.append(nil)
@ -24,7 +24,7 @@ extension Megrez.Compositor {
/// ///
/// - Parameter node: /// - Parameter node:
/// - Returns: /// - Returns:
@discardableResult public mutating func append(node: Node) -> Bool { @discardableResult public func append(node: Node) -> Bool {
guard (1...maxSpanLength).contains(node.spanLength) else { guard (1...maxSpanLength).contains(node.spanLength) else {
return false return false
} }
@ -36,7 +36,7 @@ extension Megrez.Compositor {
/// ///
/// - Parameter length: /// - Parameter length:
/// - Returns: /// - Returns:
@discardableResult public mutating func dropNodesOfOrBeyond(length: Int) -> Bool { @discardableResult public func dropNodesOfOrBeyond(length: Int) -> Bool {
guard (1...maxSpanLength).contains(length) else { guard (1...maxSpanLength).contains(length) else {
return false return false
} }
@ -66,7 +66,7 @@ extension Megrez.Compositor {
/// ///
/// - Parameter location: /// - Parameter location:
/// - Returns: /// - Returns:
func fetchOverlappingNodes(at location: Int) -> [NodeAnchor] { internal func fetchOverlappingNodes(at location: Int) -> [NodeAnchor] {
var results = [NodeAnchor]() var results = [NodeAnchor]()
guard !spans.isEmpty, location < spans.count else { return results } guard !spans.isEmpty, location < spans.count else { return results }

View File

@ -39,7 +39,7 @@ extension Megrez.Compositor {
public private(set) var spanLength: Int public private(set) var spanLength: Int
public private(set) var unigrams: [Megrez.Unigram] public private(set) var unigrams: [Megrez.Unigram]
public private(set) var currentUnigramIndex: Int = 0 { public private(set) var currentUnigramIndex: Int = 0 {
didSet { currentUnigramIndex = min(max(0, currentUnigramIndex), unigrams.count - 1) } didSet { currentUnigramIndex = max(min(unigrams.count - 1, currentUnigramIndex), 0) }
} }
public var currentPair: Megrez.Compositor.KeyValuePaired { .init(key: key, value: value) } public var currentPair: Megrez.Compositor.KeyValuePaired { .init(key: key, value: value) }
@ -53,6 +53,18 @@ extension Megrez.Compositor {
hasher.combine(overrideType) hasher.combine(overrideType)
} }
/// Replaces this node's unigram data with `source`.
///
/// The current selection is kept when the unigram at the (clamped) current index still
/// carries the previously selected value; otherwise `currentUnigramIndex` goes back to 0.
/// An empty `source`, or a node that had no unigrams, is handled without trapping.
/// - Parameter source: The new unigram data.
public func resetUnigrams(using source: [Megrez.Unigram]) {
  // Subscripting an empty array traps, so read the old selection only when there is one.
  let oldCurrentValue = unigrams.isEmpty ? nil : unigrams[currentUnigramIndex].value
  unigrams = source
  // Optional fallback, left disabled as in the original:
  // if unigrams.isEmpty { unigrams.append(.init(value: key, score: -114.514)) }
  currentUnigramIndex = max(min(unigrams.count - 1, currentUnigramIndex), 0)
  // Nothing left to compare against when the new data is empty.
  guard !unigrams.isEmpty else { return }
  let newCurrentValue = unigrams[currentUnigramIndex].value
  if oldCurrentValue != newCurrentValue { currentUnigramIndex = 0 }
}
public private(set) var overrideType: Node.OverrideType public private(set) var overrideType: Node.OverrideType
public static func == (lhs: Node, rhs: Node) -> Bool { public static func == (lhs: Node, rhs: Node) -> Bool {

View File

@ -36,6 +36,13 @@ class SimpleLM: LangModelProtocol {
/// Tells whether the database holds an entry for `key`.
///
/// - Parameter key: The joint key to look up.
/// - Returns: `true` when `mutDatabase` has a value stored under `key`.
func hasUnigramsFor(key: String) -> Bool {
  let storedEntry = mutDatabase[key]
  return storedEntry != nil
}
/// Removes every unigram whose value equals `value` from the entry stored under `key`.
///
/// When that leaves the entry empty, the entry itself is deleted, so `hasUnigramsFor`
/// stops reporting the key. Nothing happens when `key` is not in the database.
/// - Parameters:
///   - key: The joint key whose entry is trimmed.
///   - value: The unigram value to drop.
func trim(key: String, value: String) {
  guard let existing = mutDatabase[key] else { return }
  let remaining = existing.filter { $0.value != value }
  if remaining.isEmpty {
    // Previously this case returned without writing back, leaving the stale entry intact.
    mutDatabase[key] = nil
  } else {
    mutDatabase[key] = remaining
  }
}
} }
class MockLM: LangModelProtocol { class MockLM: LangModelProtocol {

View File

@ -11,7 +11,7 @@ import XCTest
final class MegrezTests: XCTestCase { final class MegrezTests: XCTestCase {
func testSpan() throws { func testSpan() throws {
let langModel = SimpleLM(input: strSampleData) let langModel = SimpleLM(input: strSampleData)
var span = Megrez.Compositor.Span() let span = Megrez.Compositor.Span()
let n1 = Megrez.Compositor.Node(keyArray: ["gao1"], spanLength: 1, unigrams: langModel.unigramsFor(key: "gao1")) let n1 = Megrez.Compositor.Node(keyArray: ["gao1"], spanLength: 1, unigrams: langModel.unigramsFor(key: "gao1"))
let n3 = Megrez.Compositor.Node( let n3 = Megrez.Compositor.Node(
keyArray: ["gao1ke1ji4"], spanLength: 3, unigrams: langModel.unigramsFor(key: "gao1ke1ji4") keyArray: ["gao1ke1ji4"], spanLength: 3, unigrams: langModel.unigramsFor(key: "gao1ke1ji4")
@ -518,4 +518,21 @@ final class MegrezTests: XCTestCase {
result = compositor.walk().0 result = compositor.walk().0
XCTAssertEqual(result.values, ["高熱", "🔥", "危險"]) XCTAssertEqual(result.values, ["高熱", "🔥", "危險"])
} }
/// Checks that `update(updateExisting: true)` picks up a language-model change made
/// after the keys were inserted.
func testCompositor_updateUnigramData() throws {
  let langModel = SimpleLM(input: strSampleData)
  var compositor = Megrez.Compositor(with: langModel)
  compositor.separator = ""
  for reading in ["nian2", "zhong1", "jiang3", "jin1"] {
    compositor.insertKey(reading)
  }
  let resultBeforeTrim = compositor.walk().0.values.joined()
  print(resultBeforeTrim)
  // Drop one candidate from the model, then resynchronize the existing nodes.
  langModel.trim(key: "nian2zhong1", value: "年中")
  compositor.update(updateExisting: true)
  let resultAfterTrim = compositor.walk().0.values.joined()
  print(resultAfterTrim)
  XCTAssertEqual([resultBeforeTrim, resultAfterTrim], ["年中獎金", "年終獎金"])
}
} }