From 3408c212d0fe4e7f0202c5e7facfcbd69f3712c1 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Mon, 8 Jan 2024 19:10:12 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=20server=20?= =?UTF-8?q?=E5=8C=85=E5=88=86=E6=B5=81=E6=B8=A0=E9=81=93=E8=AE=BE=E8=AE=A1?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=E9=83=A8=E5=88=86=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 RingBuffer 实现分流渠道的无界缓冲区,修复分流渠道被关闭后,未处理完成的消息将会被丢弃的问题; - 移除 server.WithDisableAutomaticReleaseShunt 可选项,分流渠道将在消息处理完毕且没有连接使用时自行释放; --- server/conn.go | 3 - server/constants.go | 5 +- server/dispatcher.go | 83 ------- server/event.go | 1 + server/internal/dispatcher/dispatcher.go | 153 +++++++++++++ server/internal/dispatcher/dispatcher_test.go | 48 ++++ server/internal/dispatcher/manager.go | 153 +++++++++++++ server/internal/dispatcher/manager_test.go | 46 ++++ server/internal/dispatcher/message.go | 6 + server/internal/dispatcher/producer.go | 5 + server/message.go | 18 ++ server/options.go | 64 +++--- server/server.go | 209 ++++++------------ server/server_test.go | 4 +- utils/buffer/ring.go | 129 +++++++---- utils/buffer/ring_benchmark_test.go | 25 +++ utils/buffer/ring_test.go | 14 ++ utils/buffer/ring_unbounded.go | 100 +++++++++ utils/buffer/ring_unbounded_benchmark_test.go | 48 ++++ utils/buffer/ring_unbounded_test.go | 33 +++ utils/buffer/unbounded_benchmark_test.go | 41 +++- 21 files changed, 863 insertions(+), 325 deletions(-) delete mode 100644 server/dispatcher.go create mode 100644 server/internal/dispatcher/dispatcher.go create mode 100644 server/internal/dispatcher/dispatcher_test.go create mode 100644 server/internal/dispatcher/manager.go create mode 100644 server/internal/dispatcher/manager_test.go create mode 100644 server/internal/dispatcher/message.go create mode 100644 server/internal/dispatcher/producer.go create mode 100644 utils/buffer/ring_benchmark_test.go create mode 100644 utils/buffer/ring_test.go create mode 100644 utils/buffer/ring_unbounded.go create mode 100644 utils/buffer/ring_unbounded_benchmark_test.go create mode 100644 utils/buffer/ring_unbounded_test.go diff --git a/server/conn.go b/server/conn.go index ad5ee9f..de38c4c 100644 --- a/server/conn.go +++ b/server/conn.go @@ -366,7 +366,4 @@ func (slf *Conn) Close(err ...error) { return } slf.server.OnConnectionClosedEvent(slf, nil) - if !slf.server.runtime.disableAutomaticReleaseShunt { - slf.server.releaseDispatcher(slf) - } } diff --git a/server/constants.go b/server/constants.go index d08a16b..b5bc5c7 100644 --- a/server/constants.go +++ b/server/constants.go @@ -7,9 +7,8 @@ import ( ) const ( - serverMultipleMark = "Minotaur Multiple Server" - serverMark = "Minotaur Server" - serverSystemDispatcher = "__system" // 系统消息分发器 + serverMultipleMark = "Minotaur Multiple Server" + serverMark = "Minotaur Server" ) const ( diff --git a/server/dispatcher.go b/server/dispatcher.go deleted file mode 100644 index 31c8e73..0000000 --- a/server/dispatcher.go +++ /dev/null @@ -1,83 +0,0 @@ -package server - -import ( - "github.com/alphadose/haxmap" - "github.com/kercylan98/minotaur/utils/buffer" - "sync/atomic" - "time" -) - -var dispatcherUnique = struct{}{} - -// generateDispatcher 生成消息分发器 -func generateDispatcher(name string, handler func(dispatcher *dispatcher, message *Message)) *dispatcher { - d := &dispatcher{ - name: name, - buf: buffer.NewUnbounded[*Message](), - handler: handler, - uniques: haxmap.New[string, struct{}](), - } - return d -} - -// dispatcher 消息分发器 -type dispatcher struct { - name string - buf *buffer.Unbounded[*Message] - uniques *haxmap.Map[string, struct{}] - handler func(dispatcher *dispatcher, message *Message) - closed uint32 - msgCount int64 -} - -func (d *dispatcher) unique(name string) bool { - _, loaded := d.uniques.GetOrSet(name, dispatcherUnique) - return loaded -} - -func (d *dispatcher) antiUnique(name string) { - d.uniques.Del(name) -} - -func (d *dispatcher) start() { - defer d.buf.Close() - for { - select { - case message, ok := <-d.buf.Get(): - if !ok { - return - } - d.buf.Load() - d.handler(d, message) - - if atomic.AddInt64(&d.msgCount, -1) <= 0 && atomic.LoadUint32(&d.closed) == 1 { - return - } - } - } -} - -func (d *dispatcher) put(message *Message) { - if atomic.CompareAndSwapUint32(&d.closed, 1, 1) { - return - } - atomic.AddInt64(&d.msgCount, 1) - d.buf.Put(message) -} - -func (d *dispatcher) close() { - atomic.CompareAndSwapUint32(&d.closed, 0, 1) - - go func() { - for { - if d.buf.IsClosed() { - return - } - if atomic.LoadInt64(&d.msgCount) <= 0 { - d.buf.Close() - return - } - time.Sleep(time.Second) - } - }() -} diff --git a/server/event.go b/server/event.go index 8268265..be7d267 100644 --- a/server/event.go +++ b/server/event.go @@ -216,6 +216,7 @@ func (slf *event) OnConnectionClosedEvent(conn *Conn, err any) { value(slf.Server, conn, err) return true }) + slf.Server.dispatcherMgr.UnBindProducer(conn.GetID()) }, log.String("Event", "OnConnectionClosedEvent")) } diff --git a/server/internal/dispatcher/dispatcher.go b/server/internal/dispatcher/dispatcher.go new file mode 100644 index 0000000..9e72f03 --- /dev/null +++ b/server/internal/dispatcher/dispatcher.go @@ -0,0 +1,153 @@ +package dispatcher + +import ( + "github.com/alphadose/haxmap" + "github.com/kercylan98/minotaur/utils/buffer" + "sync" + "sync/atomic" +) + +var unique = struct{}{} + +// Handler 消息处理器 +type Handler[P Producer, M Message[P]] func(dispatcher *Dispatcher[P, M], message M) + +// NewDispatcher 生成消息分发器 +func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handler Handler[P, M]) *Dispatcher[P, M] { + d := &Dispatcher[P, M]{ + name: name, + buf: buffer.NewRingUnbounded[M](bufferSize), + handler: handler, + uniques: haxmap.New[string, struct{}](), + pmc: make(map[P]int64), + pmcF: make(map[P]func(p P, dispatcher *Dispatcher[P, M])), + abort: make(chan struct{}), + } + return d +} + +// Dispatcher 消息分发器 +type Dispatcher[P Producer, M Message[P]] struct { + buf *buffer.RingUnbounded[M] + uniques *haxmap.Map[string, struct{}] + handler Handler[P, M] + expel bool + mc int64 + pmc map[P]int64 + pmcF map[P]func(p P, dispatcher *Dispatcher[P, M]) + lock sync.RWMutex + name string + closedHandler atomic.Pointer[func(dispatcher *Dispatcher[P, M])] + abort chan struct{} +} + +// SetProducerDoneHandler 设置特定生产者所有消息处理完成时的回调函数 +func (d *Dispatcher[P, M]) SetProducerDoneHandler(p P, handler func(p P, dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] { + d.lock.Lock() + if handler == nil { + delete(d.pmcF, p) + } else { + d.pmcF[p] = handler + } + d.lock.Unlock() + return d +} + +// SetClosedHandler 设置消息分发器关闭时的回调函数 +func (d *Dispatcher[P, M]) SetClosedHandler(handler func(dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] { + d.closedHandler.Store(&handler) + return d +} + +// Name 获取消息分发器名称 +func (d *Dispatcher[P, M]) Name() string { + return d.name +} + +// Unique 设置唯一消息键,返回是否已存在 +func (d *Dispatcher[P, M]) Unique(name string) bool { + _, loaded := d.uniques.GetOrSet(name, unique) + return loaded +} + +// AntiUnique 取消唯一消息键 +func (d *Dispatcher[P, M]) AntiUnique(name string) { + d.uniques.Del(name) +} + +// Expel 设置该消息分发器即将被驱逐,当消息分发器中没有任何消息时,会自动关闭 +func (d *Dispatcher[P, M]) Expel() { + d.lock.Lock() + d.expel = true + if d.mc <= 0 { + d.abort <- struct{}{} + } + d.lock.Unlock() +} + +// UnExpel 取消特定生产者的驱逐计划 +func (d *Dispatcher[P, M]) UnExpel() { + d.lock.Lock() + d.expel = false + d.lock.Unlock() +} + +// IncrCount 主动增量设置特定生产者的消息计数,这在等待异步消息完成后再关闭消息分发器时非常有用 +func (d *Dispatcher[P, M]) IncrCount(producer P, i int64) { + d.lock.Lock() + d.mc += i + d.pmc[producer] += i + if d.expel && d.mc <= 0 { + d.abort <- struct{}{} + } + d.lock.Unlock() +} + +// Put 将消息放入分发器 +func (d *Dispatcher[P, M]) Put(message M) { + d.lock.Lock() + d.mc++ + d.pmc[message.GetProducer()]++ + d.lock.Unlock() + d.buf.Write(message) +} + +// Start 以阻塞的方式开始进行消息分发,当消息分发器中没有任何消息时,会自动关闭 +func (d *Dispatcher[P, M]) Start() *Dispatcher[P, M] { + go func(d *Dispatcher[P, M]) { + process: + for { + select { + case <-d.abort: + d.buf.Close() + break process + case message := <-d.buf.Read(): + d.handler(d, message) + d.lock.Lock() + d.mc-- + p := message.GetProducer() + pmc := d.pmc[p] - 1 + d.pmc[p] = pmc + if f := d.pmcF[p]; f != nil && pmc <= 0 { + go f(p, d) + } + if d.mc <= 0 && d.expel { + d.buf.Close() + break process + } + d.lock.Unlock() + } + } + closedHandler := *(d.closedHandler.Load()) + if closedHandler != nil { + closedHandler(d) + } + close(d.abort) + }(d) + return d +} + +// Closed 判断消息分发器是否已关闭 +func (d *Dispatcher[P, M]) Closed() bool { + return d.buf.Closed() +} diff --git a/server/internal/dispatcher/dispatcher_test.go b/server/internal/dispatcher/dispatcher_test.go new file mode 100644 index 0000000..456cc2c --- /dev/null +++ b/server/internal/dispatcher/dispatcher_test.go @@ -0,0 +1,48 @@ +package dispatcher_test + +import ( + "github.com/kercylan98/minotaur/server/internal/dispatcher" + "sync" + "testing" + "time" +) + +type TestMessage struct { + producer string + v int +} + +func (m *TestMessage) GetProducer() string { + return m.producer +} + +func TestDispatcher_PutStartClose(t *testing.T) { + // 写入完成后,关闭分发器再开始分发,确保消息不会丢失 + w := new(sync.WaitGroup) + cw := new(sync.WaitGroup) + cw.Add(1) + d := dispatcher.NewDispatcher[string, *TestMessage](1024*16, "test", func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { + t.Log(message) + w.Done() + }).SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { + t.Log("closed") + cw.Done() + }) + + for i := 0; i < 100; i++ { + w.Add(1) + d.Put(&TestMessage{ + producer: "test", + v: i, + }) + } + + d.Start() + d.Expel() + d.UnExpel() + w.Wait() + time.Sleep(time.Second) + d.Expel() + cw.Wait() + t.Log("done") +} diff --git a/server/internal/dispatcher/manager.go b/server/internal/dispatcher/manager.go new file mode 100644 index 0000000..84662e0 --- /dev/null +++ b/server/internal/dispatcher/manager.go @@ -0,0 +1,153 @@ +package dispatcher + +import ( + "sync" +) + +const SystemName = "*system" + +// NewManager 生成消息分发器管理器 +func NewManager[P Producer, M Message[P]](bufferSize int, handler Handler[P, M]) *Manager[P, M] { + mgr := &Manager[P, M]{ + handler: handler, + dispatchers: make(map[string]*Dispatcher[P, M]), + member: make(map[string]map[P]struct{}), + sys: NewDispatcher(bufferSize, SystemName, handler).Start(), + curr: make(map[P]*Dispatcher[P, M]), + size: bufferSize, + } + + return mgr +} + +// Manager 消息分发器管理器 +type Manager[P Producer, M Message[P]] struct { + handler Handler[P, M] // 消息处理器 + sys *Dispatcher[P, M] // 系统消息分发器 + dispatchers map[string]*Dispatcher[P, M] // 当前所有正在工作的消息分发器 + member map[string]map[P]struct{} // 当前正在工作的消息分发器对应的生产者 + curr map[P]*Dispatcher[P, M] // 当前特定生产者正在使用的消息分发器 + lock sync.RWMutex // 消息分发器锁 + w sync.WaitGroup // 消息分发器等待组 + size int // 消息分发器缓冲区大小 + + closedHandler func(name string) + createdHandler func(name string) +} + +// SetDispatcherClosedHandler 设置消息分发器关闭时的回调函数 +func (m *Manager[P, M]) SetDispatcherClosedHandler(handler func(name string)) *Manager[P, M] { + m.closedHandler = handler + return m +} + +// SetDispatcherCreatedHandler 设置消息分发器创建时的回调函数 +func (m *Manager[P, M]) SetDispatcherCreatedHandler(handler func(name string)) *Manager[P, M] { + m.createdHandler = handler + return m +} + +// HasDispatcher 检查是否存在指定名称的消息分发器 +func (m *Manager[P, M]) HasDispatcher(name string) bool { + m.lock.RLock() + defer m.lock.RUnlock() + _, exist := m.dispatchers[name] + return exist +} + +// GetDispatcherNum 获取当前正在工作的消息分发器数量 +func (m *Manager[P, M]) GetDispatcherNum() int { + m.lock.RLock() + defer m.lock.RUnlock() + return len(m.dispatchers) + 1 // +1 系统消息分发器 +} + +// GetSystemDispatcher 获取系统消息分发器 +func (m *Manager[P, M]) GetSystemDispatcher() *Dispatcher[P, M] { + return m.sys +} + +// GetDispatcher 获取生产者正在使用的消息分发器,如果生产者没有绑定消息分发器,则会返回系统消息分发器 +func (m *Manager[P, M]) GetDispatcher(p P) *Dispatcher[P, M] { + m.lock.Lock() + defer m.lock.Unlock() + + curr, exist := m.curr[p] + if exist { + return curr + } + + return m.sys +} + +// BindProducer 绑定生产者使用特定的消息分发器,如果生产者已经绑定了消息分发器,则会先解绑 +func (m *Manager[P, M]) BindProducer(p P, name string) { + if name == SystemName { + return + } + m.lock.Lock() + defer m.lock.Unlock() + member, exist := m.member[name] + if !exist { + member = make(map[P]struct{}) + m.member[name] = member + } + + if _, exist = member[p]; exist { + d := m.dispatchers[name] + d.SetProducerDoneHandler(p, nil) + d.UnExpel() + + return + } + + curr, exist := m.curr[p] + if exist { + delete(m.member[curr.name], p) + if len(m.member[curr.name]) == 0 { + curr.Expel() + } + } + + dispatcher, exist := m.dispatchers[name] + if !exist { + dispatcher = NewDispatcher(m.size, name, m.handler).SetClosedHandler(func(dispatcher *Dispatcher[P, M]) { + // 消息分发器关闭时,将会将其从管理器中移除 + m.lock.Lock() + delete(m.dispatchers, dispatcher.name) + delete(m.member, dispatcher.name) + m.lock.Unlock() + if m.closedHandler != nil { + m.closedHandler(dispatcher.name) + } + }).Start() + m.dispatchers[name] = dispatcher + defer func(m *Manager[P, M], name string) { + if m.createdHandler != nil { + m.createdHandler(name) + } + }(m, dispatcher.Name()) + } + m.curr[p] = dispatcher + member[p] = struct{}{} +} + +// UnBindProducer 解绑生产者使用特定的消息分发器 +func (m *Manager[P, M]) UnBindProducer(p P) { + m.lock.Lock() + defer m.lock.Unlock() + curr, exist := m.curr[p] + if !exist { + return + } + + curr.SetProducerDoneHandler(p, func(p P, dispatcher *Dispatcher[P, M]) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.member[dispatcher.name], p) + delete(m.curr, p) + if len(m.member[dispatcher.name]) == 0 { + dispatcher.Expel() + } + }) +} diff --git a/server/internal/dispatcher/manager_test.go b/server/internal/dispatcher/manager_test.go new file mode 100644 index 0000000..ad660b6 --- /dev/null +++ b/server/internal/dispatcher/manager_test.go @@ -0,0 +1,46 @@ +package dispatcher_test + +import ( + "github.com/kercylan98/minotaur/server/internal/dispatcher" + "github.com/kercylan98/minotaur/utils/times" + "testing" + "time" +) + +func TestManager(t *testing.T) { + var mgr *dispatcher.Manager[string, *TestMessage] + var onHandler = func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) { + t.Log(dispatcher.Name(), message, mgr.GetDispatcherNum()) + switch message.v { + case 4: + mgr.UnBindProducer("test") + t.Log("UnBindProducer") + case 6: + mgr.BindProducer(message.GetProducer(), "test-dispatcher") + t.Log("BindProducer") + case 9: + dispatcher.Put(&TestMessage{ + producer: "test", + v: 10, + }) + case 10: + mgr.UnBindProducer("test") + t.Log("UnBindProducer", mgr.GetDispatcherNum()) + } + + } + mgr = dispatcher.NewManager[string, *TestMessage](1024*16, onHandler) + + mgr.BindProducer("test", "test-dispatcher") + for i := 0; i < 10; i++ { + d := mgr.GetDispatcher("test").SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) { + t.Log("closed") + }) + d.Put(&TestMessage{ + producer: "test", + v: i, + }) + } + + time.Sleep(times.Day) +} diff --git a/server/internal/dispatcher/message.go b/server/internal/dispatcher/message.go new file mode 100644 index 0000000..efd4373 --- /dev/null +++ b/server/internal/dispatcher/message.go @@ -0,0 +1,6 @@ +package dispatcher + +type Message[P comparable] interface { + // GetProducer 获取消息生产者 + GetProducer() P +} diff --git a/server/internal/dispatcher/producer.go b/server/internal/dispatcher/producer.go new file mode 100644 index 0000000..92d7325 --- /dev/null +++ b/server/internal/dispatcher/producer.go @@ -0,0 +1,5 @@ +package dispatcher + +type Producer interface { + comparable +} diff --git a/server/message.go b/server/message.go index 67342dc..ed48ecc 100644 --- a/server/message.go +++ b/server/message.go @@ -75,6 +75,7 @@ func HasMessageType(mt MessageType) bool { // Message 服务器消息 type Message struct { + producer string conn *Conn ordinaryHandler func() exceptionHandler func() error @@ -86,6 +87,10 @@ type Message struct { marks []log.Field } +func (slf *Message) GetProducer() string { + return slf.producer +} + // reset 重置消息结构体 func (slf *Message) reset() { slf.conn = nil @@ -126,78 +131,91 @@ func (slf MessageType) String() string { // castToPacketMessage 将消息转换为数据包消息 func (slf *Message) castToPacketMessage(conn *Conn, packet []byte, mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.packet, slf.marks = MessageTypePacket, conn, packet, mark return slf } // castToTickerMessage 将消息转换为定时器消息 func (slf *Message) castToTickerMessage(name string, caller func(), mark ...log.Field) *Message { + slf.producer = "sys" slf.t, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeTicker, name, caller, mark return slf } // castToShuntTickerMessage 将消息转换为分发器定时器消息 func (slf *Message) castToShuntTickerMessage(conn *Conn, name string, caller func(), mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.name, slf.ordinaryHandler, slf.marks = MessageTypeShuntTicker, conn, name, caller, mark return slf } // castToAsyncMessage 将消息转换为异步消息 func (slf *Message) castToAsyncMessage(caller func() error, callback func(err error), mark ...log.Field) *Message { + slf.producer = "sys" slf.t, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeAsync, caller, callback, mark return slf } // castToAsyncCallbackMessage 将消息转换为异步回调消息 func (slf *Message) castToAsyncCallbackMessage(err error, caller func(err error), mark ...log.Field) *Message { + slf.producer = "sys" slf.t, slf.err, slf.errHandler, slf.marks = MessageTypeAsyncCallback, err, caller, mark return slf } // castToShuntAsyncMessage 将消息转换为分流异步消息 func (slf *Message) castToShuntAsyncMessage(conn *Conn, caller func() error, callback func(err error), mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeShuntAsync, conn, caller, callback, mark return slf } // castToShuntAsyncCallbackMessage 将消息转换为分流异步回调消息 func (slf *Message) castToShuntAsyncCallbackMessage(conn *Conn, err error, caller func(err error), mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.err, slf.errHandler, slf.marks = MessageTypeShuntAsyncCallback, conn, err, caller, mark return slf } // castToUniqueAsyncMessage 将消息转换为唯一异步消息 func (slf *Message) castToUniqueAsyncMessage(unique string, caller func() error, callback func(err error), mark ...log.Field) *Message { + slf.producer = "sys" slf.t, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueAsync, unique, caller, callback, mark return slf } // castToUniqueAsyncCallbackMessage 将消息转换为唯一异步回调消息 func (slf *Message) castToUniqueAsyncCallbackMessage(unique string, err error, caller func(err error), mark ...log.Field) *Message { + slf.producer = "sys" slf.t, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueAsyncCallback, unique, err, caller, mark return slf } // castToUniqueShuntAsyncMessage 将消息转换为唯一分流异步消息 func (slf *Message) castToUniqueShuntAsyncMessage(conn *Conn, unique string, caller func() error, callback func(err error), mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.name, slf.exceptionHandler, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsync, conn, unique, caller, callback, mark return slf } // castToUniqueShuntAsyncCallbackMessage 将消息转换为唯一分流异步回调消息 func (slf *Message) castToUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, caller func(err error), mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.name, slf.err, slf.errHandler, slf.marks = MessageTypeUniqueShuntAsyncCallback, conn, unique, err, caller, mark return slf } // castToSystemMessage 将消息转换为系统消息 func (slf *Message) castToSystemMessage(caller func(), mark ...log.Field) *Message { + slf.producer = "sys" slf.t, slf.ordinaryHandler, slf.marks = MessageTypeSystem, caller, mark return slf } // castToShuntMessage 将消息转换为分流消息 func (slf *Message) castToShuntMessage(conn *Conn, caller func(), mark ...log.Field) *Message { + slf.producer = conn.GetID() slf.t, slf.conn, slf.ordinaryHandler, slf.marks = MessageTypeShunt, conn, caller, mark return slf } diff --git a/server/options.go b/server/options.go index 8e923cc..9497d33 100644 --- a/server/options.go +++ b/server/options.go @@ -32,26 +32,26 @@ type option struct { } type runtime struct { - deadlockDetect time.Duration // 是否开启死锁检测 - supportMessageTypes map[int]bool // websocket 模式下支持的消息类型 - certFile, keyFile string // TLS文件 - tickerPool *timer.Pool // 定时器池 - ticker *timer.Ticker // 定时器 - tickerAutonomy bool // 定时器是否独立运行 - connTickerSize int // 连接定时器大小 - websocketReadDeadline time.Duration // websocket 连接超时时间 - websocketCompression int // websocket 压缩等级 - websocketWriteCompression bool // websocket 写入压缩 - limitLife time.Duration // 限制最大生命周期 - packetWarnSize int // 数据包大小警告 - messageStatisticsDuration time.Duration // 消息统计时长 - messageStatisticsLimit int // 消息统计数量 - messageStatistics []*atomic.Int64 // 消息统计数量 - messageStatisticsLock *sync.RWMutex // 消息统计锁 - connWriteBufferSize int // 连接写入缓冲区大小 - disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道 - websocketUpgrader *websocket.Upgrader // websocket 升级器 - websocketConnInitializer func(writer http.ResponseWriter, request *http.Request, conn *websocket.Conn) error // websocket 连接初始化 + deadlockDetect time.Duration // 是否开启死锁检测 + supportMessageTypes map[int]bool // websocket 模式下支持的消息类型 + certFile, keyFile string // TLS文件 + tickerPool *timer.Pool // 定时器池 + ticker *timer.Ticker // 定时器 + tickerAutonomy bool // 定时器是否独立运行 + connTickerSize int // 连接定时器大小 + websocketReadDeadline time.Duration // websocket 连接超时时间 + websocketCompression int // websocket 压缩等级 + websocketWriteCompression bool // websocket 写入压缩 + limitLife time.Duration // 限制最大生命周期 + packetWarnSize int // 数据包大小警告 + messageStatisticsDuration time.Duration // 消息统计时长 + messageStatisticsLimit int // 消息统计数量 + messageStatistics []*atomic.Int64 // 消息统计数量 + messageStatisticsLock *sync.RWMutex // 消息统计锁 + connWriteBufferSize int // 连接写入缓冲区大小 + websocketUpgrader *websocket.Upgrader // websocket 升级器 + websocketConnInitializer func(writer http.ResponseWriter, request *http.Request, conn *websocket.Conn) error // websocket 连接初始化 + dispatcherBufferSize int // 消息分发器缓冲区大小 } // WithWebsocketConnInitializer 通过 websocket 连接初始化的方式创建服务器,当 initializer 返回错误时,服务器将不会处理该连接的后续逻辑 @@ -77,14 +77,6 @@ func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option { } } -// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器 -// - 默认不开启,当禁用自动释放分流渠道时,服务器将不会在连接断开时自动释放分流渠道,需要手动调用 ReleaseShunt 方法释放 -func WithDisableAutomaticReleaseShunt() Option { - return func(srv *Server) { - srv.runtime.disableAutomaticReleaseShunt = true - } -} - // WithConnWriteBufferSize 通过连接写入缓冲区大小的方式创建服务器 // - 默认值为 DefaultConnWriteBufferSize // - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存 @@ -100,14 +92,14 @@ func WithConnWriteBufferSize(size int) Option { // WithDispatcherBufferSize 通过消息分发器缓冲区大小的方式创建服务器 // - 默认值为 DefaultDispatcherBufferSize // - 设置合适的缓冲区大小可以提高服务器性能,但是会占用更多的内存 -//func WithDispatcherBufferSize(size int) Option { -// return func(srv *Server) { -// if size <= 0 { -// return -// } -// srv.dispatcherBufferSize = size -// } -//} +func WithDispatcherBufferSize(size int) Option { + return func(srv *Server) { + if size <= 0 { + return + } + srv.dispatcherBufferSize = size + } +} // WithMessageStatistics 通过消息统计的方式创建服务器 // - 默认不开启,当 duration 和 limit 均大于 0 的时候,服务器将记录每 duration 期间的消息数量,并保留最多 limit 条 diff --git a/server/server.go b/server/server.go index 33d7064..38503c0 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/kercylan98/minotaur/server/internal/dispatcher" "github.com/kercylan98/minotaur/server/internal/logger" "github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/log" @@ -21,7 +22,6 @@ import ( "os" "os/signal" "runtime/debug" - "sync" "sync/atomic" "syscall" "time" @@ -32,17 +32,15 @@ func New(network Network, options ...Option) *Server { network.check() server := &Server{ runtime: &runtime{ - packetWarnSize: DefaultPacketWarnSize, - connWriteBufferSize: DefaultConnWriteBufferSize, + packetWarnSize: DefaultPacketWarnSize, + connWriteBufferSize: DefaultConnWriteBufferSize, + dispatcherBufferSize: DefaultDispatcherBufferSize, }, - hub: &hub{}, - option: &option{}, - network: network, - closeChannel: make(chan struct{}, 1), - systemSignal: make(chan os.Signal, 1), - dispatchers: make(map[string]*dispatcher), - dispatcherMember: map[string]map[string]*Conn{}, - currDispatcher: map[string]*dispatcher{}, + hub: &hub{}, + option: &option{}, + network: network, + closeChannel: make(chan struct{}, 1), + systemSignal: make(chan os.Signal, 1), } server.ctx, server.cancel = context.WithCancel(context.Background()) server.event = newEvent(server) @@ -67,32 +65,29 @@ func New(network Network, options ...Option) *Server { // Server 网络服务器 type Server struct { - *event // 事件 - *runtime // 运行时 - *option // 可选项 - *hub // 连接集合 - ginServer *gin.Engine // HTTP模式下的路由器 - httpServer *http.Server // HTTP模式下的服务器 - grpcServer *grpc.Server // GRPC模式下的服务器 - gServer *gNet // TCP或UDP模式下的服务器 - multiple *MultipleServer // 多服务器模式下的服务器 - ants *ants.Pool // 协程池 - messagePool *concurrent.Pool[*Message] // 消息池 - ctx context.Context // 上下文 - cancel context.CancelFunc // 停止上下文 - systemDispatcher *dispatcher // 系统消息分发器 - systemSignal chan os.Signal // 系统信号 - closeChannel chan struct{} // 关闭信号 - multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误 - dispatchers map[string]*dispatcher // 消息分发器集合 - dispatcherMember map[string]map[string]*Conn // 消息分发器包含的连接 - currDispatcher map[string]*dispatcher // 当前连接所处消息分发器 - dispatcherLock sync.RWMutex // 消息分发器锁 - messageCounter atomic.Int64 // 消息计数器 - addr string // 侦听地址 - network Network // 网络类型 - closed uint32 // 服务器是否已关闭 - services []func() // 服务 + *event // 事件 + *runtime // 运行时 + *option // 可选项 + *hub // 连接集合 + dispatcherMgr *dispatcher.Manager[string, *Message] // 消息分发器管理器 + ginServer *gin.Engine // HTTP模式下的路由器 + httpServer *http.Server // HTTP模式下的服务器 + grpcServer *grpc.Server // GRPC模式下的服务器 + gServer *gNet // TCP或UDP模式下的服务器 + multiple *MultipleServer // 多服务器模式下的服务器 + ants *ants.Pool // 协程池 + messagePool *concurrent.Pool[*Message] // 消息池 + ctx context.Context // 上下文 + cancel context.CancelFunc // 停止上下文 + systemSignal chan os.Signal // 系统信号 + closeChannel chan struct{} // 关闭信号 + multipleRuntimeErrorChan chan error // 多服务器模式下的运行时错误 + + messageCounter atomic.Int64 // 消息计数器 + addr string // 侦听地址 + network Network // 网络类型 + closed uint32 // 服务器是否已关闭 + services []func() // 服务 } // preCheckAndAdaptation 预检查及适配 @@ -221,13 +216,6 @@ func (srv *Server) shutdown(err error) { srv.ants.Release() srv.ants = nil } - srv.dispatcherLock.Lock() - for s, d := range srv.dispatchers { - srv.OnShuntChannelClosedEvent(d.name) - d.close() - delete(srv.dispatchers, s) - } - srv.dispatcherLock.Unlock() if srv.grpcServer != nil { srv.grpcServer.GracefulStop() } @@ -300,107 +288,27 @@ func (srv *Server) GetMessageCount() int64 { // UseShunt 切换连接所使用的消息分流渠道,当分流渠道 name 不存在时将会创建一个新的分流渠道,否则将会加入已存在的分流渠道 // - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道时,将会使用指定的消息分流渠道进行消息分发 -// - 在使用 WithDisableAutomaticReleaseShunt 创建服务器后,必须始终在连接不再使用后主动通过 ReleaseShunt 释放消息分流渠道,否则将造成内存泄漏 +// - 分流渠道会在连接断开时标记为驱逐状态,当分流渠道中的所有消息处理完毕且没有新连接使用时,将会被清除 +// +// 一些有趣的情况: +// - 当连接发送异步消息时,消息会被分为两部分,分别是异步部分和回调部分。异步部分会在当前的分流渠道中处理,而回调部分则是根据回调时所在的分流渠道进行处理 func (srv *Server) UseShunt(conn *Conn, name string) { - srv.dispatcherLock.Lock() - defer srv.dispatcherLock.Unlock() - d, exist := srv.dispatchers[name] - if !exist { - d = generateDispatcher(name, srv.dispatchMessage) - srv.OnShuntChannelCreatedEvent(d.name) - go d.start() - srv.dispatchers[name] = d - } - - curr, exist := srv.currDispatcher[conn.GetID()] - if exist { - if curr.name == name { - return - } - - delete(srv.dispatcherMember[curr.name], conn.GetID()) - if curr.name != serverSystemDispatcher && len(srv.dispatcherMember[curr.name]) == 0 { - delete(srv.dispatchers, curr.name) - srv.OnShuntChannelClosedEvent(d.name) - curr.close() - } - } - srv.currDispatcher[conn.GetID()] = d - - member, exist := srv.dispatcherMember[name] - if !exist { - member = map[string]*Conn{} - srv.dispatcherMember[name] = member - } - - member[conn.GetID()] = conn + srv.dispatcherMgr.BindProducer(conn.GetID(), name) } // HasShunt 检查特定消息分流渠道是否存在 func (srv *Server) HasShunt(name string) bool { - srv.dispatcherLock.RLock() - defer srv.dispatcherLock.RUnlock() - _, exist := srv.dispatchers[name] - return exist + return srv.dispatcherMgr.HasDispatcher(name) } // GetConnCurrShunt 获取连接当前所使用的消息分流渠道 func (srv *Server) GetConnCurrShunt(conn *Conn) string { - srv.dispatcherLock.RLock() - defer srv.dispatcherLock.RUnlock() - d, exist := srv.currDispatcher[conn.GetID()] - if exist { - return d.name - } - return serverSystemDispatcher + return srv.dispatcherMgr.GetDispatcher(conn.GetID()).Name() } // GetShuntNum 获取消息分流渠道数量 func (srv *Server) GetShuntNum() int { - srv.dispatcherLock.RLock() - defer srv.dispatcherLock.RUnlock() - return len(srv.dispatchers) -} - -// getConnDispatcher 获取连接所使用的消息分发器 -func (srv *Server) getConnDispatcher(conn *Conn) *dispatcher { - if conn == nil { - return srv.systemDispatcher - } - srv.dispatcherLock.RLock() - defer srv.dispatcherLock.RUnlock() - d, exist := srv.currDispatcher[conn.GetID()] - if exist { - return d - } - return srv.systemDispatcher -} - -// ReleaseShunt 释放分流渠道中的连接,当分流渠道中不再存在连接时将会自动释放分流渠道 -// - 在未使用 WithDisableAutomaticReleaseShunt 选项时,当连接关闭时将会自动释放分流渠道中连接的资源占用 -// - 若执行过程中连接正在使用,将会切换至系统通道 -func (srv *Server) ReleaseShunt(conn *Conn) { - srv.releaseDispatcher(conn) -} - -// releaseDispatcher 关闭消息分发器 -func (srv *Server) releaseDispatcher(conn *Conn) { - if conn == nil { - return - } - cid := conn.GetID() - srv.dispatcherLock.Lock() - defer srv.dispatcherLock.Unlock() - d, exist := srv.currDispatcher[cid] - if exist && d.name != serverSystemDispatcher { - delete(srv.dispatcherMember[d.name], cid) - if len(srv.dispatcherMember[d.name]) == 0 { - srv.OnShuntChannelClosedEvent(d.name) - d.close() - delete(srv.dispatchers, d.name) - } - delete(srv.currDispatcher, cid) - } + return srv.dispatcherMgr.GetDispatcherNum() } // pushMessage 向服务器中写入特定类型的消息,需严格遵守消息属性要求 @@ -409,25 +317,29 @@ func (srv *Server) pushMessage(message *Message) { srv.messagePool.Release(message) return } - var dispatcher *dispatcher + var d *dispatcher.Dispatcher[string, *Message] switch message.t { case MessageTypePacket, MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback, MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback, MessageTypeShunt: - dispatcher = srv.getConnDispatcher(message.conn) + d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID()) case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker: - dispatcher = srv.systemDispatcher + d = srv.dispatcherMgr.GetSystemDispatcher() } - if dispatcher == nil { + if d == nil { return } - if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && dispatcher.unique(message.name) { + if (message.t == MessageTypeUniqueShuntAsync || message.t == MessageTypeUniqueAsync) && d.Unique(message.name) { srv.messagePool.Release(message) return } + switch message.t { + case MessageTypeShuntAsync, MessageTypeUniqueShuntAsync: + d.IncrCount(message.conn.GetID(), 1) + } srv.hitMessageStatistics() - dispatcher.put(message) + d.Put(message) } func (srv *Server) low(message *Message, present time.Time, expect time.Duration, messageReplace ...string) { @@ -456,7 +368,7 @@ func (srv *Server) low(message *Message, present time.Time, expect time.Duration } // dispatchMessage 消息分发 -func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) { +func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message) { var ( ctx context.Context cancel context.CancelFunc @@ -476,7 +388,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) { present := time.Now() if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync { - defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher, msg *Message, present time.Time) { + defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message, present time.Time) { super.Handle(cancel) if err := super.RecoverTransform(recover()); err != nil { stack := string(debug.Stack()) @@ -485,7 +397,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) { srv.OnMessageErrorEvent(msg, err) } if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback { - dispatcherIns.antiUnique(msg.name) + dispatcherIns.AntiUnique(msg.name) } srv.low(msg, present, time.Millisecond*100) @@ -512,10 +424,14 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) { msg.ordinaryHandler() case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync: if err := srv.ants.Submit(func() { - defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher, msg *Message, present time.Time) { + defer func(cancel context.CancelFunc, srv *Server, dispatcherIns *dispatcher.Dispatcher[string, *Message], msg *Message, present time.Time) { + switch msg.t { + case MessageTypeShuntAsync, MessageTypeUniqueShuntAsync: + dispatcherIns.IncrCount(msg.conn.GetID(), -1) + } if err := super.RecoverTransform(recover()); err != nil { if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync { - dispatcherIns.antiUnique(msg.name) + dispatcherIns.AntiUnique(msg.name) } stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack)) @@ -550,7 +466,7 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher, msg *Message) { srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler) return } - dispatcherIns.antiUnique(msg.name) + dispatcherIns.AntiUnique(msg.name) if err != nil { log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack()))) } @@ -769,7 +685,8 @@ func onMessageSystemInit(srv *Server) { }, ) srv.startMessageStatistics() - srv.systemDispatcher = generateDispatcher(serverSystemDispatcher, srv.dispatchMessage) - go srv.systemDispatcher.start() + srv.dispatcherMgr = dispatcher.NewManager[string, *Message](srv.dispatcherBufferSize, srv.dispatchMessage). + SetDispatcherCreatedHandler(srv.OnShuntChannelCreatedEvent). + SetDispatcherClosedHandler(srv.OnShuntChannelClosedEvent) srv.OnMessageReadyEvent() } diff --git a/server/server_test.go b/server/server_test.go index 006ae7a..7e8fe1f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,7 +18,7 @@ func TestNew(t *testing.T) { fmt.Println("启动完成") }) srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { - fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount()) + fmt.Println("关闭", conn.GetID(), err, "IncrCount", srv.GetOnlineCount()) }) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { @@ -38,7 +38,7 @@ func TestNew2(t *testing.T) { fmt.Println("启动完成") }) srv.RegConnectionClosedEvent(func(srv *server.Server, conn *server.Conn, err any) { - fmt.Println("关闭", conn.GetID(), err, "Count", srv.GetOnlineCount()) + fmt.Println("关闭", conn.GetID(), err, "IncrCount", srv.GetOnlineCount()) }) srv.RegConnectionReceivePacketEvent(func(srv *server.Server, conn *server.Conn, packet []byte) { diff --git a/utils/buffer/ring.go b/utils/buffer/ring.go index 8c4b3df..2d90929 100644 --- a/utils/buffer/ring.go +++ b/utils/buffer/ring.go @@ -1,15 +1,21 @@ package buffer -// NewRing 创建一个环形缓冲区 -func NewRing[T any](initSize int) *Ring[T] { - if initSize <= 1 { - panic("initial size must be great than one") +// NewRing 创建一个并发不安全的环形缓冲区 +// - initSize: 初始容量 +// +// 当初始容量小于 2 或未设置时,将会使用默认容量 2 +func NewRing[T any](initSize ...int) *Ring[T] { + if len(initSize) == 0 { + initSize = append(initSize, 2) + } + if initSize[0] < 2 { + initSize[0] = 2 } return &Ring[T]{ - buf: make([]T, initSize), - initSize: initSize, - size: initSize, + buf: make([]T, initSize[0]), + initSize: initSize[0], + size: initSize[0], } } @@ -23,91 +29,120 @@ type Ring[T any] struct { } // Read 读取数据 -func (slf *Ring[T]) Read() (T, error) { +func (b *Ring[T]) Read() (T, error) { var t T - if slf.r == slf.w { + if b.r == b.w { return t, ErrBufferIsEmpty } - v := slf.buf[slf.r] - slf.r++ - if slf.r == slf.size { - slf.r = 0 + v := b.buf[b.r] + b.r++ + if b.r == b.size { + b.r = 0 } return v, nil } +// ReadAll 读取所有数据 +func (b *Ring[T]) ReadAll() []T { + if b.r == b.w { + return nil // 没有数据时返回空切片 + } + + var length int + var data []T + + if b.w > b.r { + length = b.w - b.r + } else { + length = len(b.buf) - b.r + b.w + } + data = make([]T, length) // 预分配空间 + + if b.w > b.r { + copy(data, b.buf[b.r:b.w]) + } else { + copied := copy(data, b.buf[b.r:]) + copy(data[copied:], b.buf[:b.w]) + } + + b.r = 0 + b.w = 0 + + return data +} + // Peek 查看数据 -func (slf *Ring[T]) Peek() (t T, err error) { - if slf.r == slf.w { +func (b *Ring[T]) Peek() (t T, err error) { + if b.r == b.w { return t, ErrBufferIsEmpty } - return slf.buf[slf.r], nil + return b.buf[b.r], nil } // Write 写入数据 -func (slf *Ring[T]) Write(v T) { - slf.buf[slf.w] = v - slf.w++ +func (b *Ring[T]) Write(v T) { + b.buf[b.w] = v + b.w++ - if slf.w == slf.size { - slf.w = 0 + if b.w == b.size { + b.w = 0 } - if slf.w == slf.r { - slf.grow() + if b.w == b.r { + b.grow() } } // grow 扩容 -func (slf *Ring[T]) grow() { +func (b *Ring[T]) grow() { var size int - if slf.size < 1024 { - size = slf.size * 2 + if b.size < 1024 { + size = b.size * 2 } else { - size = slf.size + slf.size/4 + size = b.size + b.size/4 } buf := make([]T, size) - copy(buf[0:], slf.buf[slf.r:]) - copy(buf[slf.size-slf.r:], slf.buf[0:slf.r]) + copy(buf[0:], b.buf[b.r:]) + copy(buf[b.size-b.r:], b.buf[0:b.r]) - slf.r = 0 - slf.w = slf.size - slf.size = size - slf.buf = buf + b.r = 0 + b.w = b.size + b.size = size + b.buf = buf } // IsEmpty 是否为空 -func (slf *Ring[T]) IsEmpty() bool { - return slf.r == slf.w +func (b *Ring[T]) IsEmpty() bool { + return b.r == b.w } // Cap 返回缓冲区容量 -func (slf *Ring[T]) Cap() int { - return slf.size +func (b *Ring[T]) Cap() int { + return b.size } // Len 返回缓冲区长度 -func (slf *Ring[T]) Len() int { - if slf.r == slf.w { +func (b *Ring[T]) Len() int { + if b.r == b.w { return 0 } - if slf.w > slf.r { - return slf.w - slf.r + if b.w > b.r { + return b.w - b.r } - return slf.size - slf.r + slf.w + return b.size - b.r + b.w } // Reset 重置缓冲区 -func (slf *Ring[T]) Reset() { - slf.r = 0 - slf.w = 0 - slf.size = slf.initSize - slf.buf = make([]T, slf.initSize) +func (b *Ring[T]) Reset() { + b.r = 0 + b.w = 0 + b.size = b.initSize + b.buf = make([]T, b.initSize) } diff --git a/utils/buffer/ring_benchmark_test.go b/utils/buffer/ring_benchmark_test.go new file mode 100644 index 0000000..4d9d11c --- /dev/null +++ b/utils/buffer/ring_benchmark_test.go @@ -0,0 +1,25 @@ +package buffer_test + +import ( + "github.com/kercylan98/minotaur/utils/buffer" + "testing" +) + +func BenchmarkRingWrite(b *testing.B) { + ring := buffer.NewRing[int](1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + ring.Write(i) + } +} + +func BenchmarkRingRead(b *testing.B) { + ring := buffer.NewRing[int](1024) + for i := 0; i < b.N; i++ { + ring.Write(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ring.Read() + } +} diff --git a/utils/buffer/ring_test.go b/utils/buffer/ring_test.go new file mode 100644 index 0000000..5461102 --- /dev/null +++ b/utils/buffer/ring_test.go @@ -0,0 +1,14 @@ +package buffer_test + +import ( + "github.com/kercylan98/minotaur/utils/buffer" + "testing" +) + +func TestNewRing(t *testing.T) { + ring := buffer.NewRing[int]() + for i := 0; i < 100; i++ { + ring.Write(i) + t.Log(ring.Read()) + } +} diff --git a/utils/buffer/ring_unbounded.go b/utils/buffer/ring_unbounded.go new file mode 100644 index 0000000..c0095fe --- /dev/null +++ b/utils/buffer/ring_unbounded.go @@ -0,0 +1,100 @@ +package buffer + +import ( + "sync" +) + +// NewRingUnbounded 创建一个并发安全的基于环形缓冲区实现的无界缓冲区 +func NewRingUnbounded[T any](bufferSize int) *RingUnbounded[T] { + ru := &RingUnbounded[T]{ + ring: NewRing[T](1024), + rc: make(chan T, bufferSize), + closedSignal: make(chan struct{}), + } + ru.cond = sync.NewCond(&ru.rrm) + + ru.process() + return ru +} + +// RingUnbounded 基于环形缓冲区实现的无界缓冲区 +type RingUnbounded[T any] struct { + ring *Ring[T] + rrm sync.Mutex + cond *sync.Cond + rc chan T + closed bool + closedMutex sync.RWMutex + closedSignal chan struct{} +} + +// Write 写入数据 +func (b *RingUnbounded[T]) Write(v T) { + b.closedMutex.RLock() + defer b.closedMutex.RUnlock() + if b.closed { + return + } + + b.rrm.Lock() + b.ring.Write(v) + b.cond.Signal() + b.rrm.Unlock() +} + +// Read 读取数据 +func (b *RingUnbounded[T]) Read() <-chan T { + return b.rc +} + +// Closed 判断缓冲区是否已关闭 +func (b *RingUnbounded[T]) Closed() bool { + b.closedMutex.RLock() + defer b.closedMutex.RUnlock() + return b.closed +} + +// Close 关闭缓冲区,关闭后将不再接收新数据,但是已有数据仍然可以读取 +func (b *RingUnbounded[T]) Close() <-chan struct{} { + b.closedMutex.Lock() + defer b.closedMutex.Unlock() + if b.closed { + return b.closedSignal + } + b.closed = true + + b.rrm.Lock() + b.cond.Signal() + b.rrm.Unlock() + return b.closedSignal +} + +func (b *RingUnbounded[T]) process() { + go func(b *RingUnbounded[T]) { + for { + b.closedMutex.RLock() + b.rrm.Lock() + vs := b.ring.ReadAll() + if len(vs) == 0 && !b.closed { + b.closedMutex.RUnlock() + b.cond.Wait() + } else { + b.closedMutex.RUnlock() + } + b.rrm.Unlock() + + b.closedMutex.RLock() + if b.closed && len(vs) == 0 { + close(b.rc) + close(b.closedSignal) + b.closedMutex.RUnlock() + break + } + + for _, v := range vs { + b.rc <- v + } + b.closedMutex.RUnlock() + } + }(b) +} diff --git a/utils/buffer/ring_unbounded_benchmark_test.go b/utils/buffer/ring_unbounded_benchmark_test.go new file mode 100644 index 0000000..4f983db --- /dev/null +++ b/utils/buffer/ring_unbounded_benchmark_test.go @@ -0,0 +1,48 @@ +package buffer_test + +import ( + "github.com/kercylan98/minotaur/utils/buffer" + "testing" +) + +func BenchmarkRingUnbounded_Write(b *testing.B) { + ring := buffer.NewRingUnbounded[int](1024 * 16) + b.ResetTimer() + for i := 0; i < b.N; i++ { + ring.Write(i) + } +} + +func BenchmarkRingUnbounded_Read(b *testing.B) { + ring := buffer.NewRingUnbounded[int](1024 * 16) + for i := 0; i < b.N; i++ { + ring.Write(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + <-ring.Read() + } +} + +func BenchmarkRingUnbounded_Write_Parallel(b *testing.B) { + ring := buffer.NewRingUnbounded[int](1024 * 16) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ring.Write(1) + } + }) +} + +func BenchmarkRingUnbounded_Read_Parallel(b *testing.B) { + ring := buffer.NewRingUnbounded[int](1024 * 16) + for i := 0; i < b.N; i++ { + ring.Write(i) + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + <-ring.Read() + } + }) +} diff --git a/utils/buffer/ring_unbounded_test.go b/utils/buffer/ring_unbounded_test.go new file mode 100644 index 0000000..d7b6959 --- /dev/null +++ b/utils/buffer/ring_unbounded_test.go @@ -0,0 +1,33 @@ +package buffer_test + +import ( + "github.com/kercylan98/minotaur/utils/buffer" + "testing" +) + +func TestRingUnbounded_Write2Read(t *testing.T) { + ring := buffer.NewRingUnbounded[int](1024 * 16) + for i := 0; i < 100; i++ { + ring.Write(i) + } + t.Log("write done") + for i := 0; i < 100; i++ { + t.Log(<-ring.Read()) + } + t.Log("read done") +} + +func TestRingUnbounded_Close(t *testing.T) { + ring := buffer.NewRingUnbounded[int](1024 * 16) + for i := 0; i < 100; i++ { + ring.Write(i) + } + t.Log("write done") + ring.Close() + t.Log("close done") + for v := range ring.Read() { + ring.Write(v) + t.Log(v) + } + t.Log("read done") +} diff --git a/utils/buffer/unbounded_benchmark_test.go b/utils/buffer/unbounded_benchmark_test.go index 99c5803..6ee69a7 100644 --- a/utils/buffer/unbounded_benchmark_test.go +++ b/utils/buffer/unbounded_benchmark_test.go @@ -5,15 +5,46 @@ import ( "testing" ) -func BenchmarkUnboundedBuffer(b *testing.B) { - ub := buffer.NewUnbounded[int]() +func BenchmarkUnbounded_Write(b *testing.B) { + u := buffer.NewUnbounded[int]() + b.ResetTimer() + for i := 0; i < b.N; i++ { + u.Put(i) + } +} +func BenchmarkUnbounded_Read(b *testing.B) { + u := buffer.NewUnbounded[int]() + for i := 0; i < b.N; i++ { + u.Put(i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + <-u.Get() + u.Load() + } +} + +func BenchmarkUnbounded_Write_Parallel(b *testing.B) { + u := buffer.NewUnbounded[int]() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - ub.Put(1) - <-ub.Get() - ub.Load() + u.Put(1) + } + }) +} + +func BenchmarkUnbounded_Read_Parallel(b *testing.B) { + u := buffer.NewUnbounded[int]() + for i := 0; i < b.N; i++ { + u.Put(i) + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + <-u.Get() + u.Load() } }) }