test: improve dispatcher.Dispatcher comments and test cases
parent c439ef6424
commit a2a9199d41
@@ -125,7 +125,7 @@ type connection struct {
 	gw       func(packet []byte)
 	data     map[any]any
 	closed   bool
-	pool     *hub.Pool[*connPacket]
+	pool     *hub.ObjectPool[*connPacket]
 	loop     writeloop.WriteLoop[*connPacket]
 	mu       sync.Mutex
 	openTime time.Time
@@ -6,7 +6,7 @@ import (
 	"sync"
 )
 
-type hub struct {
+type connMgr struct {
 	connections map[string]*Conn // all connections
 
 	register chan *Conn // connection registration
@@ -26,12 +26,12 @@ type hubBroadcast struct {
 	filter func(conn *Conn) bool // connections for which this returns false are filtered out
 }
 
-func (h *hub) run(ctx context.Context) {
+func (h *connMgr) run(ctx context.Context) {
 	h.connections = make(map[string]*Conn)
 	h.register = make(chan *Conn, DefaultConnHubBufferSize)
 	h.unregister = make(chan string, DefaultConnHubBufferSize)
 	h.broadcast = make(chan hubBroadcast, DefaultConnHubBufferSize)
-	go func(ctx context.Context, h *hub) {
+	go func(ctx context.Context, h *connMgr) {
 		for {
 			select {
 			case conn := <-h.register:
@@ -54,7 +54,7 @@ func (h *hub) run(ctx context.Context) {
 }
 
 // registerConn registers a connection
-func (h *hub) registerConn(conn *Conn) {
+func (h *connMgr) registerConn(conn *Conn) {
 	select {
 	case h.register <- conn:
 	default:
@@ -63,7 +63,7 @@ func (h *hub) registerConn(conn *Conn) {
 }
 
 // unregisterConn unregisters a connection
-func (h *hub) unregisterConn(id string) {
+func (h *connMgr) unregisterConn(id string) {
 	select {
 	case h.unregister <- id:
 	default:
@@ -72,21 +72,21 @@ func (h *hub) unregisterConn(id string) {
 }
 
 // GetOnlineCount returns the number of online connections
-func (h *hub) GetOnlineCount() int {
+func (h *connMgr) GetOnlineCount() int {
 	h.chanMutex.RLock()
 	defer h.chanMutex.RUnlock()
 	return h.onlineCount
 }
 
 // GetOnlineBotCount returns the number of online bots
-func (h *hub) GetOnlineBotCount() int {
+func (h *connMgr) GetOnlineBotCount() int {
 	h.chanMutex.RLock()
 	defer h.chanMutex.RUnlock()
 	return h.botCount
 }
 
 // IsOnline reports whether the connection with the given id is online
-func (h *hub) IsOnline(id string) bool {
+func (h *connMgr) IsOnline(id string) bool {
 	h.chanMutex.RLock()
 	_, exist := h.connections[id]
 	h.chanMutex.RUnlock()
@@ -94,7 +94,7 @@ func (h *hub) IsOnline(id string) bool {
 }
 
 // GetOnlineAll returns all online connections
-func (h *hub) GetOnlineAll() map[string]*Conn {
+func (h *connMgr) GetOnlineAll() map[string]*Conn {
 	h.chanMutex.RLock()
 	cop := collection.CloneMap(h.connections)
 	h.chanMutex.RUnlock()
@@ -102,7 +102,7 @@ func (h *hub) GetOnlineAll() map[string]*Conn {
 }
 
 // GetOnline returns the online connection with the given id
-func (h *hub) GetOnline(id string) *Conn {
+func (h *connMgr) GetOnline(id string) *Conn {
 	h.chanMutex.RLock()
 	conn := h.connections[id]
 	h.chanMutex.RUnlock()
@@ -110,7 +110,7 @@ func (h *hub) GetOnline(id string) *Conn {
 }
 
 // CloseConn closes the connection with the given id
-func (h *hub) CloseConn(id string) {
+func (h *connMgr) CloseConn(id string) {
 	h.chanMutex.RLock()
 	conn := h.connections[id]
 	h.chanMutex.RUnlock()
@@ -120,7 +120,7 @@ func (h *hub) CloseConn(id string) {
 }
 
 // Broadcast broadcasts a packet to online connections
-func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) {
+func (h *connMgr) Broadcast(packet []byte, filter ...func(conn *Conn) bool) {
 	m := hubBroadcast{
 		packet: packet,
 	}
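Since `connMgr` is embedded in `Server` (see the `Server` struct hunk below), the renamed `Broadcast` is callable directly on a server instance. A minimal usage sketch under stated assumptions: `server.New` and the filter signature come from this diff, while `server.NetworkWebsocket` and `conn.GetID()` are assumptions used only for illustration.

```go
package main

import "github.com/kercylan98/minotaur/server"

func main() {
	// NetworkWebsocket is an assumed network constant; substitute any supported Network.
	srv := server.New(server.NetworkWebsocket)

	// Send a packet to every online connection; connections for which
	// the filter returns false are skipped.
	srv.Broadcast([]byte("server notice"), func(conn *server.Conn) bool {
		return conn.GetID() != "bot-1" // GetID is a hypothetical accessor
	})
}
```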
@@ -134,7 +134,7 @@ func (h *hub) Broadcast(packet []byte, filter ...func(conn *Conn) bool) {
 	}
 }
 
-func (h *hub) onRegister(conn *Conn) {
+func (h *connMgr) onRegister(conn *Conn) {
 	h.chanMutex.Lock()
 	if h.closed {
 		conn.Close()
@@ -148,7 +148,7 @@ func (h *hub) onRegister(conn *Conn) {
 	h.chanMutex.Unlock()
 }
 
-func (h *hub) onUnregister(id string) {
+func (h *connMgr) onUnregister(id string) {
 	h.chanMutex.Lock()
 	if conn, ok := h.connections[id]; ok {
 		h.onlineCount--
@@ -160,7 +160,7 @@ func (h *hub) onUnregister(id string) {
 	h.chanMutex.Unlock()
 }
 
-func (h *hub) onBroadcast(packet hubBroadcast) {
+func (h *connMgr) onBroadcast(packet hubBroadcast) {
 	h.chanMutex.RLock()
 	defer h.chanMutex.RUnlock()
 	for _, conn := range h.connections {
@@ -1,6 +1,7 @@
 package dispatcher
 
 import (
+	"fmt"
 	"github.com/alphadose/haxmap"
 	"github.com/kercylan98/minotaur/utils/buffer"
 	"sync"
@@ -12,8 +13,11 @@ var unique = struct{}{}
 // Handler is a message handler
 type Handler[P Producer, M Message[P]] func(dispatcher *Dispatcher[P, M], message M)
 
-// NewDispatcher creates a message dispatcher
+// NewDispatcher creates a new message dispatcher (Dispatcher) instance
 func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handler Handler[P, M]) *Dispatcher[P, M] {
+	if bufferSize <= 0 || handler == nil {
+		panic(fmt.Errorf("bufferSize must be greater than 0 and handler must not be nil, but got bufferSize: %d, handler is nil: %v", bufferSize, handler == nil))
+	}
 	d := &Dispatcher[P, M]{
 		name: name,
 		buf:  buffer.NewRingUnbounded[M](bufferSize),
@@ -26,7 +30,19 @@ func NewDispatcher[P Producer, M Message[P]](bufferSize int, name string, handle
 	return d
 }
 
-// Dispatcher is a message dispatcher
+// Dispatcher is the message dispatcher used for server message handling
+//
+// The dispatcher is a concurrency-safe producer/consumer model: producers can be of any type, while consumers must implement the Message interface.
+// Producers can concurrently and safely put messages into the dispatcher via Put. Message processing never blocks Put, and Put may be called before Start.
+// After Start is called, the dispatcher blocks while reading messages from the buffer and then runs the message handler; handler execution does not block message production.
+//
+// To guarantee that no message is lost, the internal buffer is implemented with buffer.RingUnbounded, and the dispatcher provides no Close method.
+// To shut a dispatcher down, schedule an expulsion via Expel; the dispatcher is released once it holds no more messages.
+// A pending expulsion plan can also be cancelled via UnExpel.
+//
+// Why Expel and UnExpel are provided:
+//   - When a connection drops while a series of messages still needs processing, closing the dispatcher directly could lose messages. Expel lets the dispatcher close only after those messages are handled.
+//   - If the connection reconnects before its messages finish processing, a pending expulsion could close the dispatcher anyway. UnExpel lets the reconnect cancel the expulsion plan.
 type Dispatcher[P Producer, M Message[P]] struct {
 	buf     *buffer.RingUnbounded[M]
 	uniques *haxmap.Map[string, struct{}]
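The new doc comment describes a Put → Start → Expel/UnExpel lifecycle. A minimal sketch of that contract, written in the style of the test file further down (ChatMessage is a stand-in for the tests' TestMessage; since the dispatcher package is internal, such code would live alongside the tests):

```go
package dispatcher_test

import (
	"fmt"

	"github.com/kercylan98/minotaur/server/internal/dispatcher"
)

// ChatMessage mirrors the test file's TestMessage: a Message[string]
// whose producer is identified by a plain string.
type ChatMessage struct {
	producer string
	text     string
}

func (m *ChatMessage) GetProducer() string { return m.producer }

func ExampleDispatcher() {
	// Type parameters are inferred from the handler, as in the tests below.
	d := dispatcher.NewDispatcher(1024, "chat", func(d *dispatcher.Dispatcher[string, *ChatMessage], m *ChatMessage) {
		fmt.Println(m.text) // runs on the dispatch side, never blocking Put
	})
	d.Put(&ChatMessage{producer: "p1", text: "hello"}) // Put is legal before Start
	d.Start()   // start draining the ring buffer (the tests call Expel right after)
	d.Expel()   // schedule release once no messages remain
	d.UnExpel() // ...or cancel the expulsion plan, e.g. after a reconnect
}
```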
@@ -42,6 +58,7 @@ type Dispatcher[P, M] struct {
 }
 
 // SetProducerDoneHandler sets the callback invoked once all messages from a specific producer have been processed
+//   - if handler is nil, the callback for that producer is removed
 func (d *Dispatcher[P, M]) SetProducerDoneHandler(p P, handler func(p P, dispatcher *Dispatcher[P, M])) *Dispatcher[P, M] {
	d.lock.Lock()
	if handler == nil {
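Continuing the sketch above, the per-producer done callback might be used as follows; the signatures are those shown in this hunk, and the nil-removal behavior is exactly what the new comment line documents.

```go
// Invoked once every message from producer "p1" has been handled.
d.SetProducerDoneHandler("p1", func(p string, d *dispatcher.Dispatcher[string, *ChatMessage]) {
	fmt.Printf("producer %s drained\n", p)
})

// Passing nil removes the callback for that producer again.
d.SetProducerDoneHandler("p1", nil)
```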
@@ -3,6 +3,7 @@ package dispatcher_test
 import (
 	"github.com/kercylan98/minotaur/server/internal/dispatcher"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -16,33 +17,174 @@ func (m *TestMessage) GetProducer() string {
 	return m.producer
 }
 
-func TestDispatcher_PutStartClose(t *testing.T) {
-	// after the writes complete, close the dispatcher and then start dispatching, ensuring no message is lost
-	w := new(sync.WaitGroup)
-	cw := new(sync.WaitGroup)
-	cw.Add(1)
-	d := dispatcher.NewDispatcher[string, *TestMessage](1024*16, "test", func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
-		t.Log(message)
-		w.Done()
-	}).SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
-		t.Log("closed")
-		cw.Done()
-	})
-
-	for i := 0; i < 100; i++ {
-		w.Add(1)
-		d.Put(&TestMessage{
-			producer: "test",
-			v:        i,
-		})
-	}
-
-	d.Start()
-	d.Expel()
-	d.UnExpel()
-	w.Wait()
-	time.Sleep(time.Second)
-	d.Expel()
-	cw.Wait()
-	t.Log("done")
-}
+func TestNewDispatcher(t *testing.T) {
+	var cases = []struct {
+		name        string
+		bufferSize  int
+		handler     dispatcher.Handler[string, *TestMessage]
+		shouldPanic bool
+	}{
+		{name: "TestNewDispatcher_BufferSize0AndHandlerNil", bufferSize: 0, handler: nil, shouldPanic: true},
+		{name: "TestNewDispatcher_BufferSize0AndHandlerNotNil", bufferSize: 0, handler: func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {}, shouldPanic: true},
+		{name: "TestNewDispatcher_BufferSize1AndHandlerNil", bufferSize: 1, handler: nil, shouldPanic: true},
+		{name: "TestNewDispatcher_BufferSize1AndHandlerNotNil", bufferSize: 1, handler: func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {}, shouldPanic: false},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			defer func() {
+				if r := recover(); r != nil && !c.shouldPanic {
+					t.Errorf("NewDispatcher() should not panic, but panic: %v", r)
+				}
+			}()
+			dispatcher.NewDispatcher(c.bufferSize, c.name, c.handler)
+		})
+	}
+}
+
+func TestDispatcher_SetProducerDoneHandler(t *testing.T) {
+	var cases = []struct {
+		name          string
+		producer      string
+		messageFinish *atomic.Bool
+		cancel        bool
+	}{
+		{name: "TestDispatcher_SetProducerDoneHandlerNotCancel", producer: "producer", cancel: false},
+		{name: "TestDispatcher_SetProducerDoneHandlerCancel", producer: "producer", cancel: true},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			c.messageFinish = &atomic.Bool{}
+			w := new(sync.WaitGroup)
+			d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
+				w.Done()
+			})
+			d.SetProducerDoneHandler("producer", func(p string, dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
+				c.messageFinish.Store(true)
+			})
+			if c.cancel {
+				d.SetProducerDoneHandler("producer", nil)
+			}
+			d.Put(&TestMessage{producer: "producer"})
+			w.Add(1)
+			d.Start()
+			w.Wait()
+			if c.messageFinish.Load() && c.cancel {
+				t.Errorf("%s should cancel, but not", c.name)
+			}
+		})
+	}
+}
+
+func TestDispatcher_SetClosedHandler(t *testing.T) {
+	var cases = []struct {
+		name                  string
+		handlerFinishMsgCount *atomic.Int64
+		msgTime               time.Duration
+		msgCount              int
+	}{
+		{name: "TestDispatcher_SetClosedHandler_Normal", msgTime: 0, msgCount: 1},
+		{name: "TestDispatcher_SetClosedHandler_MessageCount1024", msgTime: 0, msgCount: 1024},
+		{name: "TestDispatcher_SetClosedHandler_MessageTime1sMessageCount3", msgTime: 1 * time.Second, msgCount: 3},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			c.handlerFinishMsgCount = &atomic.Int64{}
+			w := new(sync.WaitGroup)
+			d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
+				time.Sleep(c.msgTime)
+				c.handlerFinishMsgCount.Add(1)
+			})
+			d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
+				w.Done()
+			})
+			for i := 0; i < c.msgCount; i++ {
+				d.Put(&TestMessage{producer: "producer"})
+			}
+			w.Add(1)
+			d.Start()
+			d.Expel()
+			w.Wait()
+			if c.handlerFinishMsgCount.Load() != int64(c.msgCount) {
+				t.Errorf("%s should finish %d messages, but finish %d", c.name, c.msgCount, c.handlerFinishMsgCount.Load())
+			}
+		})
+	}
+}
+
+func TestDispatcher_Expel(t *testing.T) {
+	var cases = []struct {
+		name                  string
+		handlerFinishMsgCount *atomic.Int64
+		msgTime               time.Duration
+		msgCount              int
+	}{
+		{name: "TestDispatcher_Expel_Normal", msgTime: 0, msgCount: 1},
+		{name: "TestDispatcher_Expel_MessageCount1024", msgTime: 0, msgCount: 1024},
+		{name: "TestDispatcher_Expel_MessageTime1sMessageCount3", msgTime: 1 * time.Second, msgCount: 3},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			c.handlerFinishMsgCount = &atomic.Int64{}
+			w := new(sync.WaitGroup)
+			d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
+				time.Sleep(c.msgTime)
+				c.handlerFinishMsgCount.Add(1)
+			})
+			d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
+				w.Done()
+			})
+			for i := 0; i < c.msgCount; i++ {
+				d.Put(&TestMessage{producer: "producer"})
+			}
+			w.Add(1)
+			d.Start()
+			d.Expel()
+			w.Wait()
+			if c.handlerFinishMsgCount.Load() != int64(c.msgCount) {
+				t.Errorf("%s should finish %d messages, but finish %d", c.name, c.msgCount, c.handlerFinishMsgCount.Load())
+			}
+		})
+	}
+}
+
+func TestDispatcher_UnExpel(t *testing.T) {
+	var cases = []struct {
+		name      string
+		closed    *atomic.Bool
+		isUnExpel bool
+		expect    bool
+	}{
+		{name: "TestDispatcher_UnExpel_Normal", isUnExpel: true, expect: false},
+		{name: "TestDispatcher_UnExpel_NotExpel", isUnExpel: false, expect: true},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			c.closed = &atomic.Bool{}
+			w := new(sync.WaitGroup)
+			d := dispatcher.NewDispatcher(1024, c.name, func(dispatcher *dispatcher.Dispatcher[string, *TestMessage], message *TestMessage) {
+				w.Done()
+			})
+			d.SetClosedHandler(func(dispatcher *dispatcher.Dispatcher[string, *TestMessage]) {
+				c.closed.Store(true)
+			})
+			d.Put(&TestMessage{producer: "producer"})
+			w.Add(1)
+			if c.isUnExpel {
+				d.Expel()
+				d.UnExpel()
+			} else {
+				d.Expel()
+			}
+			d.Start()
+			w.Wait()
+			if c.closed.Load() != c.expect {
+				t.Errorf("%s should %v, but %v", c.name, c.expect, c.closed.Load())
+			}
+		})
+	}
+}
@@ -36,7 +36,7 @@ func New(network Network, options ...Option) *Server {
 			connWriteBufferSize:  DefaultConnWriteBufferSize,
 			dispatcherBufferSize: DefaultDispatcherBufferSize,
 		},
-		hub:          &hub{},
+		connMgr:      &connMgr{},
 		option:       &option{},
 		network:      network,
 		closeChannel: make(chan struct{}, 1),
@@ -68,7 +68,7 @@ type Server struct {
 	*event   // events
 	*runtime // runtime
 	*option  // options
-	*hub     // connection set
+	*connMgr // connection set
 	dispatcherMgr *dispatcher.Manager[string, *Message] // message dispatcher manager
 	ginServer     *gin.Engine                           // router in HTTP mode
 	httpServer    *http.Server                          // server in HTTP mode
@@ -76,7 +76,7 @@ type Server struct {
 	gServer      *gNet                     // server in TCP or UDP mode
 	multiple     *MultipleServer           // servers in multi-server mode
 	ants         *ants.Pool                // goroutine pool
-	messagePool  *hub.Pool[*Message]       // message pool
+	messagePool  *hub.ObjectPool[*Message] // message pool
 	ctx          context.Context           // context
 	cancel       context.CancelFunc        // context cancellation
 	systemSignal chan os.Signal            // system signals
@@ -100,7 +100,7 @@ func (srv *Server) preCheckAndAdaptation(addr string) (startState <-chan error,
 		kcp.SystemTimedSched.Close()
 	}
 
-	srv.hub.run(srv.ctx)
+	srv.connMgr.run(srv.ctx)
 	return srv.network.adaptation(srv), nil
 }