other: ecs 基本实现

This commit is contained in:
kercylan98 2024-04-10 19:08:07 +08:00
parent cc3573b792
commit dff6faa834
17 changed files with 823 additions and 0 deletions

8
toolkit/buffer/doc.go Normal file
View File

@ -0,0 +1,8 @@
// Package buffer 提供了缓冲区相关的实用程序。
//
// 包括创建、读取和写入缓冲区的函数。
//
// 这个包还提供了一个无界缓冲区的实现,可以在不使用额外 goroutine 的情况下实现无界缓冲区。
//
// 无界缓冲区的所有方法都是线程安全的,除了用于同步的互斥锁外,不会阻塞任何东西。
package buffer

7
toolkit/buffer/errors.go Normal file
View File

@ -0,0 +1,7 @@
package buffer
import "errors"
var (
ErrBufferIsEmpty = errors.New("buffer is empty")
)

148
toolkit/buffer/ring.go Normal file
View File

@ -0,0 +1,148 @@
package buffer
// NewRing 创建一个并发不安全的环形缓冲区
// - initSize: 初始容量
//
// 当初始容量小于 2 或未设置时,将会使用默认容量 2
func NewRing[T any](initSize ...int) *Ring[T] {
if len(initSize) == 0 {
initSize = append(initSize, 2)
}
if initSize[0] < 2 {
initSize[0] = 2
}
return &Ring[T]{
buf: make([]T, initSize[0]),
initSize: initSize[0],
size: initSize[0],
}
}
// Ring 环形缓冲区
type Ring[T any] struct {
buf []T
initSize int
size int
r int
w int
}
// Read 读取数据
func (b *Ring[T]) Read() (T, error) {
var t T
if b.r == b.w {
return t, ErrBufferIsEmpty
}
v := b.buf[b.r]
b.r++
if b.r == b.size {
b.r = 0
}
return v, nil
}
// ReadAll 读取所有数据
func (b *Ring[T]) ReadAll() []T {
if b.r == b.w {
return nil // 没有数据时返回空切片
}
var length int
var data []T
if b.w > b.r {
length = b.w - b.r
} else {
length = len(b.buf) - b.r + b.w
}
data = make([]T, length) // 预分配空间
if b.w > b.r {
copy(data, b.buf[b.r:b.w])
} else {
copied := copy(data, b.buf[b.r:])
copy(data[copied:], b.buf[:b.w])
}
b.r = 0
b.w = 0
return data
}
// Peek 查看数据
func (b *Ring[T]) Peek() (t T, err error) {
if b.r == b.w {
return t, ErrBufferIsEmpty
}
return b.buf[b.r], nil
}
// Write 写入数据
func (b *Ring[T]) Write(v T) {
b.buf[b.w] = v
b.w++
if b.w == b.size {
b.w = 0
}
if b.w == b.r {
b.grow()
}
}
// grow 扩容
func (b *Ring[T]) grow() {
var size int
if b.size < 1024 {
size = b.size * 2
} else {
size = b.size + b.size/4
}
buf := make([]T, size)
copy(buf[0:], b.buf[b.r:])
copy(buf[b.size-b.r:], b.buf[0:b.r])
b.r = 0
b.w = b.size
b.size = size
b.buf = buf
}
// IsEmpty 是否为空
func (b *Ring[T]) IsEmpty() bool {
return b.r == b.w
}
// Cap 返回缓冲区容量
func (b *Ring[T]) Cap() int {
return b.size
}
// Len 返回缓冲区长度
func (b *Ring[T]) Len() int {
if b.r == b.w {
return 0
}
if b.w > b.r {
return b.w - b.r
}
return b.size - b.r + b.w
}
// Reset 重置缓冲区
func (b *Ring[T]) Reset() {
b.r = 0
b.w = 0
b.size = b.initSize
b.buf = make([]T, b.initSize)
}

View File

@ -0,0 +1,25 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func BenchmarkRing_Write(b *testing.B) {
ring := buffer.NewRing[int](1024)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ring.Write(i)
}
}
func BenchmarkRing_Read(b *testing.B) {
ring := buffer.NewRing[int](1024)
for i := 0; i < b.N; i++ {
ring.Write(i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ring.Read()
}
}

View File

@ -0,0 +1,14 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func TestNewRing(t *testing.T) {
ring := buffer.NewRing[int]()
for i := 0; i < 100; i++ {
ring.Write(i)
t.Log(ring.Read())
}
}

View File

@ -0,0 +1,100 @@
package buffer
import (
"sync"
)
// NewRingUnbounded 创建一个并发安全的基于环形缓冲区实现的无界缓冲区
func NewRingUnbounded[T any](bufferSize int) *RingUnbounded[T] {
ru := &RingUnbounded[T]{
ring: NewRing[T](1024),
rc: make(chan T, bufferSize),
closedSignal: make(chan struct{}),
}
ru.cond = sync.NewCond(&ru.rrm)
ru.process()
return ru
}
// RingUnbounded 基于环形缓冲区实现的无界缓冲区
type RingUnbounded[T any] struct {
ring *Ring[T]
rrm sync.Mutex
cond *sync.Cond
rc chan T
closed bool
closedMutex sync.RWMutex
closedSignal chan struct{}
}
// Write 写入数据
func (b *RingUnbounded[T]) Write(v T) {
b.closedMutex.RLock()
defer b.closedMutex.RUnlock()
if b.closed {
return
}
b.rrm.Lock()
b.ring.Write(v)
b.cond.Signal()
b.rrm.Unlock()
}
// Read 读取数据
func (b *RingUnbounded[T]) Read() <-chan T {
return b.rc
}
// Closed 判断缓冲区是否已关闭
func (b *RingUnbounded[T]) Closed() bool {
b.closedMutex.RLock()
defer b.closedMutex.RUnlock()
return b.closed
}
// Close 关闭缓冲区,关闭后将不再接收新数据,但是已有数据仍然可以读取
func (b *RingUnbounded[T]) Close() <-chan struct{} {
b.closedMutex.Lock()
defer b.closedMutex.Unlock()
if b.closed {
return b.closedSignal
}
b.closed = true
b.rrm.Lock()
b.cond.Signal()
b.rrm.Unlock()
return b.closedSignal
}
func (b *RingUnbounded[T]) process() {
go func(b *RingUnbounded[T]) {
for {
b.closedMutex.RLock()
b.rrm.Lock()
vs := b.ring.ReadAll()
if len(vs) == 0 && !b.closed {
b.closedMutex.RUnlock()
b.cond.Wait()
} else {
b.closedMutex.RUnlock()
}
b.rrm.Unlock()
b.closedMutex.RLock()
if b.closed && len(vs) == 0 {
close(b.rc)
close(b.closedSignal)
b.closedMutex.RUnlock()
break
}
for _, v := range vs {
b.rc <- v
}
b.closedMutex.RUnlock()
}
}(b)
}

View File

@ -0,0 +1,48 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func BenchmarkRingUnbounded_Write(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ring.Write(i)
}
}
func BenchmarkRingUnbounded_Read(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < b.N; i++ {
ring.Write(i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
<-ring.Read()
}
}
func BenchmarkRingUnbounded_Write_Parallel(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
ring.Write(1)
}
})
}
func BenchmarkRingUnbounded_Read_Parallel(b *testing.B) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < b.N; i++ {
ring.Write(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
<-ring.Read()
}
})
}

View File

@ -0,0 +1,33 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func TestRingUnbounded_Write2Read(t *testing.T) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < 100; i++ {
ring.Write(i)
}
t.Log("write done")
for i := 0; i < 100; i++ {
t.Log(<-ring.Read())
}
t.Log("read done")
}
func TestRingUnbounded_Close(t *testing.T) {
ring := buffer.NewRingUnbounded[int](1024 * 16)
for i := 0; i < 100; i++ {
ring.Write(i)
}
t.Log("write done")
ring.Close()
t.Log("close done")
for v := range ring.Read() {
ring.Write(v)
t.Log(v)
}
t.Log("read done")
}

View File

@ -0,0 +1,79 @@
package buffer
import (
"sync"
)
// NewUnbounded 创建一个无界缓冲区
// - generateNil: 生成空值的函数,该函数仅需始终返回 nil 即可
//
// 该缓冲区来源于 gRPC 的实现,用于在不使用额外 goroutine 的情况下实现无界缓冲区
// - 该缓冲区的所有方法都是线程安全的,除了用于同步的互斥锁外,不会阻塞任何东西
func NewUnbounded[V any]() *Unbounded[V] {
return &Unbounded[V]{c: make(chan V, 1)}
}
// Unbounded 是无界缓冲区的实现
type Unbounded[V any] struct {
c chan V
closed bool
mu sync.Mutex
backlog []V
}
// Put 将数据放入缓冲区
func (slf *Unbounded[V]) Put(t V) {
slf.mu.Lock()
defer slf.mu.Unlock()
if slf.closed {
return
}
if len(slf.backlog) == 0 {
select {
case slf.c <- t:
return
default:
}
}
slf.backlog = append(slf.backlog, t)
}
// Load 将缓冲区中的数据发送到读取通道中,如果缓冲区中没有数据,则不会发送
// - 在每次 Get 后都应该执行该函数
func (slf *Unbounded[V]) Load() {
slf.mu.Lock()
defer slf.mu.Unlock()
if slf.closed {
return
}
if len(slf.backlog) > 0 {
select {
case slf.c <- slf.backlog[0]:
slf.backlog = slf.backlog[1:]
default:
}
}
}
// Get 获取读取通道
func (slf *Unbounded[V]) Get() <-chan V {
return slf.c
}
// Close 关闭
func (slf *Unbounded[V]) Close() {
slf.mu.Lock()
defer slf.mu.Unlock()
if slf.closed {
return
}
slf.closed = true
close(slf.c)
}
// IsClosed 是否已关闭
func (slf *Unbounded[V]) IsClosed() bool {
slf.mu.Lock()
defer slf.mu.Unlock()
return slf.closed
}

View File

@ -0,0 +1,50 @@
package buffer_test
import (
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func BenchmarkUnbounded_Write(b *testing.B) {
u := buffer.NewUnbounded[int]()
b.ResetTimer()
for i := 0; i < b.N; i++ {
u.Put(i)
}
}
func BenchmarkUnbounded_Read(b *testing.B) {
u := buffer.NewUnbounded[int]()
for i := 0; i < b.N; i++ {
u.Put(i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
<-u.Get()
u.Load()
}
}
func BenchmarkUnbounded_Write_Parallel(b *testing.B) {
u := buffer.NewUnbounded[int]()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
u.Put(1)
}
})
}
func BenchmarkUnbounded_Read_Parallel(b *testing.B) {
u := buffer.NewUnbounded[int]()
for i := 0; i < b.N; i++ {
u.Put(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
<-u.Get()
u.Load()
}
})
}

View File

@ -0,0 +1,17 @@
package buffer_test
import (
"fmt"
"github.com/kercylan98/minotaur/utils/buffer"
"testing"
)
func TestUnbounded_Get(t *testing.T) {
ub := buffer.NewUnbounded[int]()
for i := 0; i < 100; i++ {
ub.Put(i + 1)
fmt.Println(<-ub.Get())
//<-ub.Get()
ub.Load()
}
}

60
toolkit/ecs/archetype.go Normal file
View File

@ -0,0 +1,60 @@
package ecs
import (
"github.com/kercylan98/minotaur/utils/super"
"reflect"
)
func newArchetype(world *World, mask *super.BitSet[ComponentId]) *archetype {
arch := &archetype{
world: world,
mask: mask,
entityData: make(map[ComponentId][]reflect.Value),
}
return arch
}
// archetype 原型是一种实体的集合,它们都包含了相同的组件
type archetype struct {
world *World
mask *super.BitSet[ComponentId]
entities []Entity
entityData map[ComponentId][]reflect.Value
}
func (a *archetype) addEntity(entity Entity) Entity {
entity.setArchetypeIndex(len(a.entities))
a.entities = append(a.entities, entity)
for _, componentId := range a.mask.Bits() {
t := a.world.getComponentTypeById(componentId)
if t == nil {
continue
}
v := reflect.New(t)
a.entityData[componentId] = append(a.entityData[componentId], v)
}
return entity
}
func (a *archetype) removeEntity(entity Entity) {
idx := entity.GetArchetypeIndex()
for componentId, values := range a.entityData {
a.entityData[componentId] = append(values[:idx], values[idx+1:]...)
}
a.entities = append(a.entities[:idx], a.entities[idx+1:]...)
}
func (a *archetype) getEntityComponentData(entity Entity, componentId ComponentId) reflect.Value {
return a.entityData[componentId][entity.GetArchetypeIndex()]
}
func (a *archetype) getEntityData(entity Entity) []reflect.Value {
var idx = entity.GetArchetypeIndex()
var data []reflect.Value
for _, componentId := range a.mask.Bits() {
data = append(data, a.entityData[componentId][idx])
}
return data
}

8
toolkit/ecs/component.go Normal file
View File

@ -0,0 +1,8 @@
package ecs
type ComponentId = int
// Component 组件是一个数据结构,它包含了一些数据
type Component struct {
id ComponentId
}

21
toolkit/ecs/entity.go Normal file
View File

@ -0,0 +1,21 @@
package ecs
type EntityId = int
// Entity 仅包含一个实体的唯一标识
type Entity struct {
id EntityId // 实体的唯一标识
archIdx int // 实体所在的原型索引
}
func (e Entity) GetId() EntityId {
return e.id
}
func (e Entity) GetArchetypeIndex() int {
return e.archIdx
}
func (e Entity) setArchetypeIndex(idx int) {
e.archIdx = idx
}

69
toolkit/ecs/query.go Normal file
View File

@ -0,0 +1,69 @@
package ecs
import (
"github.com/kercylan98/minotaur/utils/super"
"reflect"
)
func QueryEntity[C any](world *World, entity Entity) *C {
t := reflect.TypeOf((*C)(nil)).Elem()
id, exist := world.components[t]
if !exist {
var ids []ComponentId
for i := range t.NumField() {
ids = append(ids, world.ComponentId(t.Field(i).Type.Elem()))
}
mask := super.NewBitSet(ids...)
for _, arch := range world.archetypes {
if !arch.mask.ContainsAll(mask) {
continue
}
values := arch.getEntityData(entity)
fields := make(map[reflect.Type]reflect.Value)
for _, value := range values {
fields[value.Type()] = value
}
result := reflect.New(t)
for i := range t.NumField() {
f := result.Elem().Field(i)
f.Set(fields[f.Type()])
}
return result.Interface().(*C)
}
return nil
}
for _, arch := range world.archetypes {
if !arch.mask.Has(id) {
continue
}
for _, e := range arch.entities {
if e == entity {
return arch.getEntityComponentData(entity, id).Interface().(*C)
}
}
}
return nil
}
func QueryEntitiesByComponentId[T any](world *World, id ComponentId) []*T {
t := reflect.TypeOf((*T)(nil)).Elem()
if world.components[t] != id {
return nil
}
var cs []*T
for _, arch := range world.archetypes {
if arch.mask.Has(id) {
for _, entity := range arch.entities {
arch.getEntityComponentData(entity, id)
cs = append(cs, arch.getEntityComponentData(entity, id).Interface().(*T))
}
}
}
return cs
}

92
toolkit/ecs/world.go Normal file
View File

@ -0,0 +1,92 @@
package ecs
import (
"github.com/kercylan98/minotaur/utils/super"
"reflect"
)
// NewWorld 创建一个新的世界
func NewWorld() World {
return World{
components: make(map[reflect.Type]int),
componentTypes: make(map[int]reflect.Type),
archetypes: make(map[*super.BitSet[int]]*archetype),
}
}
type World struct {
componentIds []ComponentId // 已经注册的组件 Id 清单
components map[reflect.Type]ComponentId // 已经注册的组件清单
componentTypes map[ComponentId]reflect.Type // 已经注册的组件类型清单
archetypes map[*super.BitSet[ComponentId]]*archetype // 已经注册的原型清单
entityGuid EntityId // 实体的唯一标识当前值
}
// CreateEntity 创建一个新的实体
func (w *World) CreateEntity(componentId ComponentId, componentIds ...ComponentId) Entity {
mask := super.NewBitSet(append([]ComponentId{componentId}, componentIds...)...)
var arch *archetype
for existingMask, existingArch := range w.archetypes {
if existingMask.Equal(mask) {
arch = existingArch
break
}
}
if arch == nil {
arch = newArchetype(w, mask)
w.archetypes[mask] = arch
}
return arch.addEntity(Entity{
id: w.entityGuid,
})
}
// ComponentId 返回一个组件的 Id如果组件未注册则注册它
func (w *World) ComponentId(t reflect.Type) ComponentId {
id, exist := w.components[t]
if !exist {
id = len(w.components)
w.components[t] = id
w.componentTypes[id] = t
w.componentIds = append(w.componentIds, id)
}
return id
}
// getComponentTypeById 通过 Id 获取一个组件的类型
func (w *World) getComponentTypeById(id ComponentId) reflect.Type {
if id < 0 || id >= len(w.componentIds) {
return nil
}
return w.componentTypes[id]
}
// unregisterComponentById 通过 Id 注销一个组件
func (w *World) unregisterComponentById(id ComponentId) {
t := w.componentTypes[id]
delete(w.components, t)
delete(w.componentTypes, id)
w.componentIds = append(w.componentIds[:id], w.componentIds[id+1:]...)
}
// unregisterComponentByType 通过类型注销一个组件
func (w *World) unregisterComponentByType(t reflect.Type) {
id, exist := w.components[t]
if !exist {
return
}
w.unregisterComponentById(id)
}
// nextEntityGuid 返回下一个实体的唯一标识
func (w *World) nextEntityGuid() EntityId {
guid := w.entityGuid
w.entityGuid++
return guid
}

44
toolkit/ecs/world_test.go Normal file
View File

@ -0,0 +1,44 @@
package ecs_test
import (
"github.com/kercylan98/minotaur/toolkit/ecs"
"github.com/kercylan98/minotaur/utils/super"
"reflect"
"testing"
)
type NameComponent struct {
Name string
}
type AgeComponent struct {
Age int
}
func TestWorld_ComponentId(t *testing.T) {
w := ecs.NewWorld()
nameComponent := w.ComponentId(reflect.TypeOf(NameComponent{}))
ageComponent := w.ComponentId(reflect.TypeOf(AgeComponent{}))
ea := w.CreateEntity(nameComponent, ageComponent)
eb := w.CreateEntity(nameComponent, ageComponent)
ecs.QueryEntity[NameComponent](&w, ea).Name = "Alice"
ecs.QueryEntity[NameComponent](&w, eb).Name = "Bob"
ecs.QueryEntity[AgeComponent](&w, ea).Age = 20
ecs.QueryEntity[AgeComponent](&w, eb).Age = 30
t.Log(string(super.MarshalJSON(ecs.QueryEntity[NameComponent](&w, ea))))
t.Log(string(super.MarshalJSON(ecs.QueryEntity[NameComponent](&w, eb))))
t.Log(string(super.MarshalJSON(ecs.QueryEntity[AgeComponent](&w, ea))))
t.Log(string(super.MarshalJSON(ecs.QueryEntity[AgeComponent](&w, eb))))
merge := ecs.QueryEntity[struct {
*NameComponent
*AgeComponent
}](&w, ea)
t.Log(string(super.MarshalJSON(merge)))
}