feat: 移除 super.BitMask 以 super.BitSet 替代,super.BitSet 是一个可动态增长的比特位集合

This commit is contained in:
kercylan98 2023-12-15 16:21:03 +08:00
parent 70694311c6
commit 05c65e9efd
4 changed files with 357 additions and 112 deletions

324
utils/super/bit_set.go Normal file
View File

@ -0,0 +1,324 @@
package super
import (
"fmt"
"github.com/kercylan98/minotaur/utils/generic"
"math/bits"
)
// NewBitSet 通过指定的 Bit 位创建一个 BitSet
func NewBitSet[Bit generic.Integer](bits ...Bit) *BitSet[Bit] {
set := &BitSet[Bit]{set: make([]uint64, 0, 1)}
for _, bit := range bits {
set.Set(bit)
}
return set
}
// BitSet 是一个可以动态增长的比特位集合
// - 默认情况下将使用 64 位无符号整数来表示比特位,当需要表示的比特位超过 64 位时,将自动增长
type BitSet[Bit generic.Integer] struct {
set []uint64 // 比特位集合
}
// Set 将指定的位 bit 设置为 1
func (slf *BitSet[Bit]) Set(bit Bit) *BitSet[Bit] {
word := bit >> 6
for word >= Bit(len(slf.set)) {
slf.set = append(slf.set, 0)
}
slf.set[word] |= 1 << (bit & 0x3f)
return slf
}
// Del 将指定的位 bit 设置为 0
func (slf *BitSet[Bit]) Del(bit Bit) *BitSet[Bit] {
word := bit >> 6
if word < Bit(len(slf.set)) {
slf.set[word] &^= 1 << (bit & 0x3f)
}
return slf
}
// Shrink 将 BitSet 中的比特位集合缩小到最小
// - 正常情况下当 BitSet 中的比特位超出 64 位时,将自动增长,当 BitSet 中的比特位数量减少时,可以使用该方法将 BitSet 中的比特位集合缩小到最小
func (slf *BitSet[Bit]) Shrink() *BitSet[Bit] {
index := len(slf.set) - 1
if slf.set[index] != 0 {
return slf
}
for i := index - 1; i >= 0; i-- {
if slf.set[i] != 0 {
slf.set = slf.set[:i+1]
return slf
}
}
return slf
}
// Cap 返回当前 BitSet 中可以表示的最大比特位数量
func (slf *BitSet[Bit]) Cap() int {
return len(slf.set) * 64
}
// Has 检查指定的位 bit 是否被设置为 1
func (slf *BitSet[Bit]) Has(bit Bit) bool {
word := bit >> 6
return word < Bit(len(slf.set)) && slf.set[word]&(1<<(bit&0x3f)) != 0
}
// Clear 清空所有的比特位
func (slf *BitSet[Bit]) Clear() *BitSet[Bit] {
slf.set = nil
return slf
}
// Len 返回当前 BitSet 中被设置的比特位数量
func (slf *BitSet[Bit]) Len() int {
var count int
for _, word := range slf.set {
count += bits.OnesCount64(word)
}
return count
}
// Bits 返回当前 BitSet 中被设置的比特位
func (slf *BitSet[Bit]) Bits() []Bit {
bits := make([]Bit, 0, slf.Len())
for i, word := range slf.set {
for j := 0; j < 64; j++ {
if word&(1<<j) != 0 {
bits = append(bits, Bit(i*64+j))
}
}
}
return bits
}
// Reverse 反转当前 BitSet 中的所有比特位
func (slf *BitSet[Bit]) Reverse() *BitSet[Bit] {
for i, word := range slf.set {
slf.set[i] = bits.Reverse64(word)
}
return slf
}
// Not 返回当前 BitSet 中所有比特位的反转
func (slf *BitSet[Bit]) Not() *BitSet[Bit] {
for i, word := range slf.set {
slf.set[i] = ^word
}
return slf
}
// And 将当前 BitSet 与另一个 BitSet 进行按位与运算
func (slf *BitSet[Bit]) And(other *BitSet[Bit]) *BitSet[Bit] {
for i, word := range other.set {
if i < len(slf.set) {
slf.set[i] &= word
} else {
slf.set = append(slf.set, word)
}
}
return slf
}
// Or 将当前 BitSet 与另一个 BitSet 进行按位或运算
func (slf *BitSet[Bit]) Or(other *BitSet[Bit]) *BitSet[Bit] {
for i, word := range other.set {
if i < len(slf.set) {
slf.set[i] |= word
} else {
slf.set = append(slf.set, word)
}
}
return slf
}
// Xor 将当前 BitSet 与另一个 BitSet 进行按位异或运算
func (slf *BitSet[Bit]) Xor(other *BitSet[Bit]) *BitSet[Bit] {
for i, word := range other.set {
if i < len(slf.set) {
slf.set[i] ^= word
} else {
slf.set = append(slf.set, word)
}
}
return slf
}
// Sub 将当前 BitSet 与另一个 BitSet 进行按位减运算
func (slf *BitSet[Bit]) Sub(other *BitSet[Bit]) *BitSet[Bit] {
for i, word := range other.set {
if i < len(slf.set) {
slf.set[i] &^= word
}
}
return slf
}
// IsZero 检查当前 BitSet 是否为空
func (slf *BitSet[Bit]) IsZero() bool {
for _, word := range slf.set {
if word != 0 {
return false
}
}
return true
}
// Clone 返回当前 BitSet 的副本
func (slf *BitSet[Bit]) Clone() *BitSet[Bit] {
other := &BitSet[Bit]{set: make([]uint64, len(slf.set))}
copy(other.set, slf.set)
return other
}
// Equal 检查当前 BitSet 是否与另一个 BitSet 相等
func (slf *BitSet[Bit]) Equal(other *BitSet[Bit]) bool {
if len(slf.set) != len(other.set) {
return false
}
for i, word := range slf.set {
if word != other.set[i] {
return false
}
}
return true
}
// Contains 检查当前 BitSet 是否包含另一个 BitSet
func (slf *BitSet[Bit]) Contains(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i >= len(slf.set) || slf.set[i]&word != word {
return false
}
}
return true
}
// ContainsAny 检查当前 BitSet 是否包含另一个 BitSet 中的任意比特位
func (slf *BitSet[Bit]) ContainsAny(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i < len(slf.set) && slf.set[i]&word != 0 {
return true
}
}
return false
}
// ContainsAll 检查当前 BitSet 是否包含另一个 BitSet 中的所有比特位
func (slf *BitSet[Bit]) ContainsAll(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i >= len(slf.set) || slf.set[i]&word != word {
return false
}
}
return true
}
// Intersect 检查当前 BitSet 是否与另一个 BitSet 有交集
func (slf *BitSet[Bit]) Intersect(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i < len(slf.set) && slf.set[i]&word != 0 {
return true
}
}
return false
}
// Union 检查当前 BitSet 是否与另一个 BitSet 有并集
func (slf *BitSet[Bit]) Union(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i < len(slf.set) && slf.set[i]&word != 0 {
return true
}
}
return false
}
// Difference 检查当前 BitSet 是否与另一个 BitSet 有差集
func (slf *BitSet[Bit]) Difference(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i < len(slf.set) && slf.set[i]&word != 0 {
return true
}
}
return false
}
// SymmetricDifference 检查当前 BitSet 是否与另一个 BitSet 有对称差集
func (slf *BitSet[Bit]) SymmetricDifference(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i < len(slf.set) && slf.set[i]&word != 0 {
return true
}
}
return false
}
// Subset 检查当前 BitSet 是否为另一个 BitSet 的子集
func (slf *BitSet[Bit]) Subset(other *BitSet[Bit]) bool {
for i, word := range other.set {
if i >= len(slf.set) || slf.set[i]&word != word {
return false
}
}
return true
}
// Superset 检查当前 BitSet 是否为另一个 BitSet 的超集
func (slf *BitSet[Bit]) Superset(other *BitSet[Bit]) bool {
for i, word := range slf.set {
if i >= len(other.set) || other.set[i]&word != word {
return false
}
}
return true
}
// Complement 检查当前 BitSet 是否为另一个 BitSet 的补集
func (slf *BitSet[Bit]) Complement(other *BitSet[Bit]) bool {
for i, word := range slf.set {
if i >= len(other.set) || other.set[i]&word != word {
return false
}
}
return true
}
// Max 返回当前 BitSet 中最大的比特位
func (slf *BitSet[Bit]) Max() Bit {
for i := len(slf.set) - 1; i >= 0; i-- {
if slf.set[i] != 0 {
return Bit(i*64 + bits.Len64(slf.set[i]) - 1)
}
}
return 0
}
// Min 返回当前 BitSet 中最小的比特位
func (slf *BitSet[Bit]) Min() Bit {
for i, word := range slf.set {
if word != 0 {
return Bit(i*64 + bits.TrailingZeros64(word))
}
}
return 0
}
// String 返回当前 BitSet 的字符串表示
func (slf *BitSet[Bit]) String() string {
return fmt.Sprintf("[%v] %v", slf.Len(), slf.Bits())
}
// MarshalJSON 实现 json.Marshaler 接口
func (slf *BitSet[Bit]) MarshalJSON() ([]byte, error) {
return MarshalJSONE(slf.set)
}
// UnmarshalJSON 实现 json.Unmarshaler 接口
func (slf *BitSet[Bit]) UnmarshalJSON(data []byte) error {
return UnmarshalJSON(data, &slf.set)
}

View File

@ -0,0 +1,33 @@
package super_test
import (
"github.com/kercylan98/minotaur/utils/super"
"testing"
)
func TestBitSet_Set(t *testing.T) {
bs := super.NewBitSet(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
bs.Set(11)
bs.Set(12)
bs.Set(13)
t.Log(bs)
}
func TestBitSet_Del(t *testing.T) {
bs := super.NewBitSet(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
bs.Del(11)
bs.Del(12)
bs.Del(13)
bs.Del(10)
t.Log(bs)
}
func TestBitSet_Shrink(t *testing.T) {
bs := super.NewBitSet(63)
t.Log(bs.Cap())
bs.Set(200)
t.Log(bs.Cap())
bs.Del(200)
bs.Shrink()
t.Log(bs.Cap())
}

View File

@ -1,96 +0,0 @@
package super
import (
"github.com/kercylan98/minotaur/utils/generic"
"math/bits"
"strconv"
)
// BitMask 使用泛型类型 Bit 表示一个比特掩码Bit 必须是整数类型。
type BitMask[Bit generic.Integer] uint64
// Mask 创建一个新的 BitMask包含给定的比特位。
// 参数 bits 是一个可变参数列表,表示要设置的比特位。
// - 使用按位或运算符 (|=) 将 bit 位置的比特设置为 1。
func Mask[Bit generic.Integer](bits ...Bit) BitMask[Bit] {
var mask Bit
for _, bit := range bits {
mask |= 1 << bit
}
return BitMask[Bit](mask)
}
// Matches 检查当前 BitMask 是否与另一个 BitMask 完全匹配。
func (slf *BitMask[Bit]) Matches(bits BitMask[Bit]) bool {
return *slf == bits
}
// Contains 检查当前 BitMask 是否包含另一个 BitMask。
// - 使用按位与运算符 (&) 计算两个 BitMask 的公共部分,然后与 bits 进行比较。
func (slf *BitMask[Bit]) Contains(bits BitMask[Bit]) bool {
return (*slf & bits) == bits
}
// Has 检查当前 BitMask 是否包含特定的比特位。
// - 使用按位与运算符 (&) 检查 bit 位置的比特是否为 1。
func (slf *BitMask[Bit]) Has(bits ...Bit) bool {
for _, bit := range bits {
if *slf&(1<<bit) == 0 {
return false
}
}
return true
}
// Add 将给定的比特位添加到当前 BitMask 中。
// - 使用按位或运算符 (|=) 将 bit 位置的比特设置为 1。
func (slf *BitMask[Bit]) Add(bit Bit) {
*slf |= 1 << bit
}
// Del 从当前 BitMask 中删除给定的比特位。
// - 使用按位与运算符 (&) 取反 (^) 然后按位与运算符 (&) 清除 bit 位置的比特。
func (slf *BitMask[Bit]) Del(bit Bit) {
*slf &= ^(1 << bit)
}
// Toggle 反转指定比特位。
// - 使用按位或运算符 (|=) 将 bit 位置的比特设置为 1如果 bit 位置的比特已经为 1则使用按位与运算符 (&^=) 将 bit 位置的比特设置为 0。
func (slf *BitMask[Bit]) Toggle(bit Bit) {
*slf ^= 1 << bit
}
// Reset 将当前 BitMask 重置为 all 0。
func (slf *BitMask[Bit]) Reset() {
*slf = 0
}
// Count 统计 BitMask 中被设置的比特位数量。
// - 使用按位与运算符 (&) 将 BitMask 与 0xFFFFFFFF 进行按位与运算,然后使用 `bits.Len()` 函数统计 0 的数量。
func (slf *BitMask[Bit]) Count() int {
return 64 - bits.Len64(uint64(*slf)&0xFFFFFFFF)
}
// Union 返回当前 BitMask 和另一个 BitMask 的并集。
// - 使用按位或运算符 (&) 将两个 BitMask 进行合并。
func (slf *BitMask[Bit]) Union(other BitMask[Bit]) BitMask[Bit] {
return *slf | other
}
// Intersection 返回当前 BitMask 和另一个 BitMask 的交集。
// - 使用按位与运算符 (&) 将两个 BitMask 进行交集。
func (slf *BitMask[Bit]) Intersection(other BitMask[Bit]) BitMask[Bit] {
return *slf & other
}
// Difference 返回当前 BitMask 和另一个 BitMask 的差集。
// - 使用按位异或运算符 (^) 将两个 BitMask 进行异或。
func (slf *BitMask[Bit]) Difference(other BitMask[Bit]) BitMask[Bit] {
return *slf ^ other
}
// String 返回当前 BitMask 的字符串表示形式。
// - 使用 `strconv.FormatUint()` 函数将 BitMask 转换为二进制字符串。
func (slf *BitMask[Bit]) String() string {
return strconv.FormatUint(uint64(*slf), 2)
}

View File

@ -1,16 +0,0 @@
package super_test
import (
"github.com/kercylan98/minotaur/utils/super"
"testing"
)
func TestMask(t *testing.T) {
mask := super.Mask(1, 2, 3, 5)
t.Log(mask.Matches(super.Mask(1, 2, 3, 5)))
t.Log(mask.Matches(super.Mask(1, 2, 3, 4, 5)))
t.Log(mask.Matches(super.Mask(1, 2, 5)))
t.Log(mask.Contains(super.Mask(1, 2, 5)))
t.Log(mask)
}